From 402e16bd2bcfb1491db34ccadc1fe11e4a389c7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Tue, 8 Feb 2022 17:03:34 +0100 Subject: [PATCH] Shader: Handle type as expressions --- include/Nazara/Graphics/UberShader.hpp | 1 - include/Nazara/Shader/Ast/AstCloner.hpp | 7 +- include/Nazara/Shader/Ast/AstCloner.inl | 19 +- include/Nazara/Shader/Ast/AstCompare.hpp | 3 +- include/Nazara/Shader/Ast/AstCompare.inl | 10 +- .../Shader/Ast/AstExpressionVisitorExcept.hpp | 2 +- include/Nazara/Shader/Ast/AstNodeList.hpp | 1 + include/Nazara/Shader/Ast/AstOptimizer.hpp | 2 + .../Nazara/Shader/Ast/AstRecursiveVisitor.hpp | 1 + include/Nazara/Shader/Ast/AstReflect.hpp | 2 +- include/Nazara/Shader/Ast/AstSerializer.hpp | 3 +- include/Nazara/Shader/Ast/AstSerializer.inl | 69 +- .../Shader/Ast/AstStatementVisitorExcept.hpp | 2 +- include/Nazara/Shader/Ast/AstTypes.hpp | 36 + include/Nazara/Shader/Ast/Attribute.hpp | 20 +- include/Nazara/Shader/Ast/Attribute.inl | 16 +- include/Nazara/Shader/Ast/ExpressionType.hpp | 62 +- include/Nazara/Shader/Ast/ExpressionType.inl | 62 +- include/Nazara/Shader/Ast/Nodes.hpp | 42 +- include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 38 +- include/Nazara/Shader/GlslWriter.hpp | 12 +- include/Nazara/Shader/LangWriter.hpp | 10 +- include/Nazara/Shader/ShaderBuilder.hpp | 23 +- include/Nazara/Shader/ShaderBuilder.inl | 35 +- include/Nazara/Shader/ShaderLangParser.hpp | 24 +- include/Nazara/Shader/SpirvAstVisitor.hpp | 7 +- include/Nazara/Shader/SpirvExpressionLoad.hpp | 24 +- .../Nazara/Shader/SpirvExpressionStore.hpp | 4 +- src/Nazara/Graphics/BasicMaterial.cpp | 4 +- src/Nazara/Graphics/PhongLightingMaterial.cpp | 4 +- .../Resources/Shaders/basic_material.nzsl | 2 +- .../Resources/Shaders/phong_material.nzsl | 2 +- src/Nazara/Graphics/UberShader.cpp | 5 +- src/Nazara/Shader/Ast/AstCloner.cpp | 47 +- .../Shader/Ast/AstExpressionVisitorExcept.cpp | 2 +- src/Nazara/Shader/Ast/AstOptimizer.cpp | 25 +- src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp | 6 + src/Nazara/Shader/Ast/AstSerializer.cpp | 182 ++- .../Shader/Ast/AstStatementVisitorExcept.cpp | 2 +- src/Nazara/Shader/Ast/ExpressionType.cpp | 33 +- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 1193 +++++++++++------ src/Nazara/Shader/GlslWriter.cpp | 163 ++- src/Nazara/Shader/LangWriter.cpp | 121 +- src/Nazara/Shader/ShaderLangParser.cpp | 329 +---- src/Nazara/Shader/SpirvAstVisitor.cpp | 14 +- src/Nazara/Shader/SpirvConstantCache.cpp | 15 +- src/Nazara/Shader/SpirvExpressionLoad.cpp | 95 +- src/Nazara/Shader/SpirvExpressionStore.cpp | 17 +- src/Nazara/Shader/SpirvWriter.cpp | 71 +- tests/Engine/Shader/AccessMemberTest.cpp | 4 +- tests/Engine/Shader/Loops.cpp | 1 - tests/Engine/Shader/Optimizations.cpp | 1 + tests/Engine/Shader/Swizzle.cpp | 12 +- 53 files changed, 1746 insertions(+), 1141 deletions(-) create mode 100644 include/Nazara/Shader/Ast/AstTypes.hpp diff --git a/include/Nazara/Graphics/UberShader.hpp b/include/Nazara/Graphics/UberShader.hpp index c285e63bd..77d7219ea 100644 --- a/include/Nazara/Graphics/UberShader.hpp +++ b/include/Nazara/Graphics/UberShader.hpp @@ -58,7 +58,6 @@ namespace Nz struct Option { std::size_t index; - ShaderAst::ExpressionType type; }; private: diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index b24c97855..1cf051d72 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -24,7 +24,8 @@ namespace Nz::ShaderAst AstCloner(AstCloner&&) = delete; ~AstCloner() = default; - template AttributeValue Clone(const AttributeValue& attribute); + template ExpressionValue Clone(const ExpressionValue& expressionValue); + inline ExpressionValue Clone(const ExpressionValue& expressionValue); ExpressionPtr Clone(Expression& statement); StatementPtr Clone(Statement& statement); @@ -37,6 +38,7 @@ namespace Nz::ShaderAst virtual ExpressionPtr CloneExpression(Expression& expr); virtual StatementPtr CloneStatement(Statement& statement); + virtual ExpressionValue CloneType(const ExpressionValue& exprType); virtual ExpressionPtr Clone(AccessIdentifierExpression& node); virtual ExpressionPtr Clone(AccessIndexExpression& node); @@ -69,6 +71,7 @@ namespace Nz::ShaderAst virtual StatementPtr Clone(MultiStatement& node); virtual StatementPtr Clone(NoOpStatement& node); virtual StatementPtr Clone(ReturnStatement& node); + virtual StatementPtr Clone(ScopedStatement& node); virtual StatementPtr Clone(WhileStatement& node); #define NAZARA_SHADERAST_NODE(NodeType) void Visit(NodeType& node) override; @@ -85,7 +88,7 @@ namespace Nz::ShaderAst std::vector m_statementStack; }; - template AttributeValue Clone(const AttributeValue& attribute); + template ExpressionValue Clone(const ExpressionValue& attribute); inline ExpressionPtr Clone(Expression& node); inline StatementPtr Clone(Statement& node); } diff --git a/include/Nazara/Shader/Ast/AstCloner.inl b/include/Nazara/Shader/Ast/AstCloner.inl index dc5158d15..29bc685a3 100644 --- a/include/Nazara/Shader/Ast/AstCloner.inl +++ b/include/Nazara/Shader/Ast/AstCloner.inl @@ -8,20 +8,25 @@ namespace Nz::ShaderAst { template - AttributeValue AstCloner::Clone(const AttributeValue& attribute) + ExpressionValue AstCloner::Clone(const ExpressionValue& expressionValue) { - if (!attribute.HasValue()) + if (!expressionValue.HasValue()) return {}; - if (attribute.IsExpression()) - return CloneExpression(attribute.GetExpression()); + if (expressionValue.IsExpression()) + return CloneExpression(expressionValue.GetExpression()); else { - assert(attribute.IsResultingValue()); - return attribute.GetResultingValue(); + assert(expressionValue.IsResultingValue()); + return expressionValue.GetResultingValue(); } } + inline ExpressionValue AstCloner::Clone(const ExpressionValue& expressionValue) + { + return CloneType(expressionValue); + } + ExpressionPtr AstCloner::CloneExpression(const ExpressionPtr& expr) { if (!expr) @@ -40,7 +45,7 @@ namespace Nz::ShaderAst template - AttributeValue Clone(const AttributeValue& attribute) + ExpressionValue Clone(const ExpressionValue& attribute) { AstCloner cloner; return cloner.Clone(attribute); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index 0683a9a7b..a88a2dbbe 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -21,7 +21,7 @@ namespace Nz::ShaderAst template bool Compare(const T& lhs, const T& rhs); template bool Compare(const std::array& lhs, const std::array& rhs); template bool Compare(const std::vector& lhs, const std::vector& rhs); - template bool Compare(const AttributeValue& lhs, const AttributeValue& rhs); + template bool Compare(const ExpressionValue& lhs, const ExpressionValue& rhs); inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs); inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs); inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs); @@ -59,6 +59,7 @@ namespace Nz::ShaderAst inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs); inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs); inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs); + inline bool Compare(const ScopedStatement& lhs, const ScopedStatement& rhs); inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs); } diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index 7c10bb09d..94fcc742a 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -78,7 +78,7 @@ namespace Nz::ShaderAst } template - bool Compare(const AttributeValue& lhs, const AttributeValue& rhs) + bool Compare(const ExpressionValue& lhs, const ExpressionValue& rhs) { if (!Compare(lhs.HasValue(), rhs.HasValue())) return false; @@ -519,6 +519,14 @@ namespace Nz::ShaderAst return true; } + bool Compare(const ScopedStatement& lhs, const ScopedStatement& rhs) + { + if (!Compare(lhs.statement, rhs.statement)) + return false; + + return true; + } + inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs) { if (!Compare(lhs.unroll, rhs.unroll)) diff --git a/include/Nazara/Shader/Ast/AstExpressionVisitorExcept.hpp b/include/Nazara/Shader/Ast/AstExpressionVisitorExcept.hpp index 4e304dea0..568d1bb69 100644 --- a/include/Nazara/Shader/Ast/AstExpressionVisitorExcept.hpp +++ b/include/Nazara/Shader/Ast/AstExpressionVisitorExcept.hpp @@ -13,7 +13,7 @@ namespace Nz::ShaderAst { - class NAZARA_SHADER_API ExpressionVisitorExcept : public AstExpressionVisitor + class NAZARA_SHADER_API AstExpressionVisitorExcept : public AstExpressionVisitor { public: using AstExpressionVisitor::Visit; diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index d81e0d9ee..84e162ac3 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -58,6 +58,7 @@ NAZARA_SHADERAST_STATEMENT(ExpressionStatement) NAZARA_SHADERAST_STATEMENT(MultiStatement) NAZARA_SHADERAST_STATEMENT(NoOpStatement) NAZARA_SHADERAST_STATEMENT(ReturnStatement) +NAZARA_SHADERAST_STATEMENT(ScopedStatement) NAZARA_SHADERAST_STATEMENT_LAST(WhileStatement) #undef NAZARA_SHADERAST_EXPRESSION diff --git a/include/Nazara/Shader/Ast/AstOptimizer.hpp b/include/Nazara/Shader/Ast/AstOptimizer.hpp index 7d2537bf2..476fffc9b 100644 --- a/include/Nazara/Shader/Ast/AstOptimizer.hpp +++ b/include/Nazara/Shader/Ast/AstOptimizer.hpp @@ -57,6 +57,8 @@ namespace Nz::ShaderAst template ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3); template ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4); + StatementPtr Unscope(StatementPtr node); + private: Options m_options; }; diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 992d4233a..f09269bd9 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -51,6 +51,7 @@ namespace Nz::ShaderAst void Visit(MultiStatement& node) override; void Visit(NoOpStatement& node) override; void Visit(ReturnStatement& node) override; + void Visit(ScopedStatement& node) override; void Visit(WhileStatement& node) override; }; } diff --git a/include/Nazara/Shader/Ast/AstReflect.hpp b/include/Nazara/Shader/Ast/AstReflect.hpp index ca0cb3846..6dd569cbb 100644 --- a/include/Nazara/Shader/Ast/AstReflect.hpp +++ b/include/Nazara/Shader/Ast/AstReflect.hpp @@ -32,7 +32,7 @@ namespace Nz::ShaderAst struct Callbacks { std::function onEntryPointDeclaration; - std::function onOptionDeclaration; + std::function& optionType)> onOptionDeclaration; }; private: diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index c1beb3edc..4fda20db6 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -54,12 +54,13 @@ namespace Nz::ShaderAst void Serialize(MultiStatement& node); void Serialize(NoOpStatement& node); void Serialize(ReturnStatement& node); + void Serialize(ScopedStatement& node); void Serialize(WhileStatement& node); protected: - template void Attribute(AttributeValue& attribute); template void Container(T& container); template void Enum(T& enumVal); + template void ExprValue(ExpressionValue& attribute); template void OptEnum(std::optional& optVal); template void OptVal(std::optional& optVal); diff --git a/include/Nazara/Shader/Ast/AstSerializer.inl b/include/Nazara/Shader/Ast/AstSerializer.inl index 08aca16f2..1b6889878 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.inl +++ b/include/Nazara/Shader/Ast/AstSerializer.inl @@ -3,12 +3,42 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include namespace Nz::ShaderAst { template - void AstSerializerBase::Attribute(AttributeValue& attribute) + void AstSerializerBase::Container(T& container) + { + bool isWriting = IsWriting(); + + UInt32 size; + if (isWriting) + size = SafeCast(container.size()); + + Value(size); + if (!isWriting) + container.resize(size); + } + + + template + void AstSerializerBase::Enum(T& enumVal) + { + bool isWriting = IsWriting(); + + UInt32 value; + if (isWriting) + value = SafeCast(enumVal); + + Value(value); + if (!isWriting) + enumVal = static_cast(value); + } + + template + void AstSerializerBase::ExprValue(ExpressionValue& attribute) { UInt32 valueType; if (IsWriting()) @@ -55,6 +85,8 @@ namespace Nz::ShaderAst T value; if constexpr (std::is_enum_v) Enum(value); + else if constexpr (std::is_same_v) + Type(value); else Value(value); @@ -65,6 +97,8 @@ namespace Nz::ShaderAst T& value = const_cast(attribute.GetResultingValue()); //< not used for writing if constexpr (std::is_enum_v) Enum(value); + else if constexpr (std::is_same_v) + Type(value); else Value(value); } @@ -74,35 +108,6 @@ namespace Nz::ShaderAst } } - template - void AstSerializerBase::Container(T& container) - { - bool isWriting = IsWriting(); - - UInt32 size; - if (isWriting) - size = UInt32(container.size()); - - Value(size); - if (!isWriting) - container.resize(size); - } - - - template - void AstSerializerBase::Enum(T& enumVal) - { - bool isWriting = IsWriting(); - - UInt32 value; - if (isWriting) - value = static_cast(enumVal); - - Value(value); - if (!isWriting) - enumVal = static_cast(value); - } - template void AstSerializerBase::OptEnum(std::optional& optVal) { @@ -150,12 +155,12 @@ namespace Nz::ShaderAst UInt32 fixedVal; if (isWriting) - fixedVal = static_cast(val); + fixedVal = SafeCast(val); Value(fixedVal); if (!isWriting) - val = static_cast(fixedVal); + val = SafeCast(fixedVal); } inline ShaderAstSerializer::ShaderAstSerializer(ByteStream& stream) : diff --git a/include/Nazara/Shader/Ast/AstStatementVisitorExcept.hpp b/include/Nazara/Shader/Ast/AstStatementVisitorExcept.hpp index 2a194aa8f..ab10dffe7 100644 --- a/include/Nazara/Shader/Ast/AstStatementVisitorExcept.hpp +++ b/include/Nazara/Shader/Ast/AstStatementVisitorExcept.hpp @@ -13,7 +13,7 @@ namespace Nz::ShaderAst { - class NAZARA_SHADER_API StatementVisitorExcept : public AstStatementVisitor + class NAZARA_SHADER_API AstStatementVisitorExcept : public AstStatementVisitor { public: using AstStatementVisitor::Visit; diff --git a/include/Nazara/Shader/Ast/AstTypes.hpp b/include/Nazara/Shader/Ast/AstTypes.hpp new file mode 100644 index 000000000..0b12cc7ff --- /dev/null +++ b/include/Nazara/Shader/Ast/AstTypes.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_AST_ASTTYPES_HPP +#define NAZARA_SHADER_AST_ASTTYPES_HPP + +#include +#include +#include + +namespace Nz::ShaderAst +{ + enum class TypeParameterCategory + { + ConstantValue, + FullType, + PrimitiveType, + StructType + }; + + struct PartialType; + + using TypeParameter = std::variant; + + struct PartialType + { + std::vector parameters; + std::function buildFunc; + }; + +} + +#endif // NAZARA_SHADER_AST_ASTTYPES_HPP diff --git a/include/Nazara/Shader/Ast/Attribute.hpp b/include/Nazara/Shader/Ast/Attribute.hpp index 669a881a7..ec3bf6667 100644 --- a/include/Nazara/Shader/Ast/Attribute.hpp +++ b/include/Nazara/Shader/Ast/Attribute.hpp @@ -20,15 +20,15 @@ namespace Nz::ShaderAst using ExpressionPtr = std::unique_ptr; template - class AttributeValue + class ExpressionValue { public: - AttributeValue() = default; - AttributeValue(T value); - AttributeValue(ExpressionPtr expr); - AttributeValue(const AttributeValue&) = default; - AttributeValue(AttributeValue&&) = default; - ~AttributeValue() = default; + ExpressionValue() = default; + ExpressionValue(T value); + ExpressionValue(ExpressionPtr expr); + ExpressionValue(const ExpressionValue&) = default; + ExpressionValue(ExpressionValue&&) noexcept = default; + ~ExpressionValue() = default; ExpressionPtr&& GetExpression() &&; const ExpressionPtr& GetExpression() const &; @@ -39,14 +39,14 @@ namespace Nz::ShaderAst bool HasValue() const; - AttributeValue& operator=(const AttributeValue&) = default; - AttributeValue& operator=(AttributeValue&&) = default; + ExpressionValue& operator=(const ExpressionValue&) = default; + ExpressionValue& operator=(ExpressionValue&&) noexcept = default; private: std::variant m_value; }; - struct Attribute + struct ExprValue { using Param = std::optional; diff --git a/include/Nazara/Shader/Ast/Attribute.inl b/include/Nazara/Shader/Ast/Attribute.inl index 0b94828cf..a7ea1e475 100644 --- a/include/Nazara/Shader/Ast/Attribute.inl +++ b/include/Nazara/Shader/Ast/Attribute.inl @@ -10,20 +10,20 @@ namespace Nz::ShaderAst { template - AttributeValue::AttributeValue(T value) : + ExpressionValue::ExpressionValue(T value) : m_value(std::move(value)) { } template - AttributeValue::AttributeValue(ExpressionPtr expr) + ExpressionValue::ExpressionValue(ExpressionPtr expr) { assert(expr); m_value = std::move(expr); } template - ExpressionPtr&& AttributeValue::GetExpression() && + ExpressionPtr&& ExpressionValue::GetExpression() && { if (!IsExpression()) throw std::runtime_error("excepted expression"); @@ -32,7 +32,7 @@ namespace Nz::ShaderAst } template - const ExpressionPtr& AttributeValue::GetExpression() const & + const ExpressionPtr& ExpressionValue::GetExpression() const & { if (!IsExpression()) throw std::runtime_error("excepted expression"); @@ -42,7 +42,7 @@ namespace Nz::ShaderAst } template - const T& AttributeValue::GetResultingValue() const + const T& ExpressionValue::GetResultingValue() const { if (!IsResultingValue()) throw std::runtime_error("excepted resulting value"); @@ -51,19 +51,19 @@ namespace Nz::ShaderAst } template - bool AttributeValue::IsExpression() const + bool ExpressionValue::IsExpression() const { return std::holds_alternative(m_value); } template - bool AttributeValue::IsResultingValue() const + bool ExpressionValue::IsResultingValue() const { return std::holds_alternative(m_value); } template - bool AttributeValue::HasValue() const + bool ExpressionValue::HasValue() const { return !std::holds_alternative(m_value); } diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index 2a5d44824..46fc3efbd 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -29,14 +29,22 @@ namespace Nz::ShaderAst ArrayType& operator=(const ArrayType& array); ArrayType& operator=(ArrayType&&) noexcept = default; - AttributeValue length; + UInt32 length; std::unique_ptr containedType; bool operator==(const ArrayType& rhs) const; inline bool operator!=(const ArrayType& rhs) const; }; - struct IdentifierType //< Alias or struct + struct FunctionType + { + std::size_t funcIndex; + + inline bool operator==(const FunctionType& rhs) const; + inline bool operator!=(const FunctionType& rhs) const; + }; + + struct IdentifierType { std::string name; @@ -44,6 +52,14 @@ namespace Nz::ShaderAst inline bool operator!=(const IdentifierType& rhs) const; }; + struct IntrinsicFunctionType + { + IntrinsicType intrinsic; + + inline bool operator==(const IntrinsicFunctionType& rhs) const; + inline bool operator!=(const IntrinsicFunctionType& rhs) const; + }; + struct MatrixType { std::size_t columnCount; @@ -54,6 +70,22 @@ namespace Nz::ShaderAst inline bool operator!=(const MatrixType& rhs) const; }; + struct NAZARA_SHADER_API MethodType + { + MethodType() = default; + MethodType(const MethodType& methodType); + MethodType(MethodType&&) noexcept = default; + + MethodType& operator=(const MethodType& methodType); + MethodType& operator=(MethodType&&) noexcept = default; + + std::unique_ptr objectType; + std::size_t methodIndex; + + bool operator==(const MethodType& rhs) const; + inline bool operator!=(const MethodType& rhs) const; + }; + struct NoType { inline bool operator==(const NoType& rhs) const; @@ -77,9 +109,17 @@ namespace Nz::ShaderAst inline bool operator!=(const StructType& rhs) const; }; + struct Type + { + std::size_t typeIndex; + + inline bool operator==(const Type& rhs) const; + inline bool operator!=(const Type& rhs) const; + }; + struct UniformType { - std::variant containedType; + StructType containedType; inline bool operator==(const UniformType& rhs) const; inline bool operator!=(const UniformType& rhs) const; @@ -94,7 +134,7 @@ namespace Nz::ShaderAst inline bool operator!=(const VectorType& rhs) const; }; - using ExpressionType = std::variant; + using ExpressionType = std::variant; struct ContainedType { @@ -105,25 +145,29 @@ namespace Nz::ShaderAst { struct StructMember { - AttributeValue builtin; - AttributeValue cond; - AttributeValue locationIndex; + ExpressionValue builtin; + ExpressionValue cond; + ExpressionValue locationIndex; + ExpressionValue type; std::string name; - ExpressionType type; }; - AttributeValue layout; + ExpressionValue layout; std::string name; std::vector members; }; inline bool IsArrayType(const ExpressionType& type); + inline bool IsFunctionType(const ExpressionType& type); inline bool IsIdentifierType(const ExpressionType& type); + inline bool IsIntrinsicFunctionType(const ExpressionType& type); inline bool IsMatrixType(const ExpressionType& type); + inline bool IsMethodType(const ExpressionType& type); inline bool IsNoType(const ExpressionType& type); inline bool IsPrimitiveType(const ExpressionType& type); inline bool IsSamplerType(const ExpressionType& type); inline bool IsStructType(const ExpressionType& type); + inline bool IsTypeExpression(const ExpressionType& type); inline bool IsUniformType(const ExpressionType& type); inline bool IsVectorType(const ExpressionType& type); } diff --git a/include/Nazara/Shader/Ast/ExpressionType.inl b/include/Nazara/Shader/Ast/ExpressionType.inl index 0fe8d0163..daefa9ec2 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.inl +++ b/include/Nazara/Shader/Ast/ExpressionType.inl @@ -14,6 +14,17 @@ namespace Nz::ShaderAst } + inline bool FunctionType::operator==(const FunctionType& rhs) const + { + return funcIndex == rhs.funcIndex; + } + + inline bool FunctionType::operator!=(const FunctionType& rhs) const + { + return !operator==(rhs); + } + + inline bool IdentifierType::operator==(const IdentifierType& rhs) const { return name == rhs.name; @@ -25,6 +36,17 @@ namespace Nz::ShaderAst } + inline bool IntrinsicFunctionType::operator==(const IntrinsicFunctionType& rhs) const + { + return intrinsic == rhs.intrinsic; + } + + inline bool IntrinsicFunctionType::operator!=(const IntrinsicFunctionType& rhs) const + { + return !operator==(rhs); + } + + inline bool MatrixType::operator==(const MatrixType& rhs) const { return columnCount == rhs.columnCount && rowCount == rhs.rowCount && type == rhs.type; @@ -36,6 +58,12 @@ namespace Nz::ShaderAst } + inline bool MethodType::operator!=(const MethodType& rhs) const + { + return !operator==(rhs); + } + + inline bool NoType::operator==(const NoType& /*rhs*/) const { return true; @@ -68,6 +96,18 @@ namespace Nz::ShaderAst return !operator==(rhs); } + + inline bool Type::operator==(const Type& rhs) const + { + return typeIndex == rhs.typeIndex; + } + + inline bool Type::operator!=(const Type& rhs) const + { + return !operator==(rhs); + } + + inline bool UniformType::operator==(const UniformType& rhs) const { return containedType == rhs.containedType; @@ -90,21 +130,36 @@ namespace Nz::ShaderAst } - bool IsArrayType(const ExpressionType& type) + inline bool IsArrayType(const ExpressionType& type) { return std::holds_alternative(type); } + inline bool IsFunctionType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + inline bool IsIdentifierType(const ExpressionType& type) { return std::holds_alternative(type); } + inline bool IsIntrinsicFunctionType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + inline bool IsMatrixType(const ExpressionType& type) { return std::holds_alternative(type); } + inline bool IsMethodType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + inline bool IsNoType(const ExpressionType& type) { return std::holds_alternative(type); @@ -125,6 +180,11 @@ namespace Nz::ShaderAst return std::holds_alternative(type); } + bool IsTypeExpression(const ExpressionType& type) + { + return std::holds_alternative(type); + } + bool IsUniformType(const ExpressionType& type) { return std::holds_alternative(type); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 2987cd980..ae7e73f1f 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -107,7 +107,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - std::variant targetFunction; + ExpressionPtr targetFunction; std::vector parameters; }; @@ -126,7 +126,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - ExpressionType targetType; + ExpressionValue targetType; std::array expressions; }; @@ -249,10 +249,10 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; + ExpressionValue type; std::optional constIndex; std::string name; ExpressionPtr expression; - ExpressionType type; }; struct NAZARA_SHADER_API DeclareExternalStatement : Statement @@ -262,13 +262,13 @@ namespace Nz::ShaderAst struct ExternalVar { - AttributeValue bindingIndex; - AttributeValue bindingSet; + ExpressionValue bindingIndex; + ExpressionValue bindingSet; + ExpressionValue type; std::string name; - ExpressionType type; }; - AttributeValue bindingSet; + ExpressionValue bindingSet; std::optional varIndex; std::vector externalVars; }; @@ -281,18 +281,18 @@ namespace Nz::ShaderAst struct Parameter { std::string name; - ExpressionType type; + ExpressionValue type; }; - AttributeValue depthWrite; - AttributeValue earlyFragmentTests; - AttributeValue entryStage; + ExpressionValue depthWrite; + ExpressionValue earlyFragmentTests; + ExpressionValue entryStage; + ExpressionValue returnType; std::optional funcIndex; std::optional varIndex; std::string name; std::vector parameters; std::vector statements; - ExpressionType returnType; }; struct NAZARA_SHADER_API DeclareOptionStatement : Statement @@ -303,7 +303,7 @@ namespace Nz::ShaderAst std::optional optIndex; std::string optName; ExpressionPtr defaultValue; - ExpressionType optType; + ExpressionValue optType; }; struct NAZARA_SHADER_API DeclareStructStatement : Statement @@ -323,7 +323,7 @@ namespace Nz::ShaderAst std::optional varIndex; std::string varName; ExpressionPtr initialExpression; - ExpressionType varType; + ExpressionValue varType; }; struct NAZARA_SHADER_API DiscardStatement : Statement @@ -345,7 +345,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; - AttributeValue unroll; + ExpressionValue unroll; std::optional varIndex; std::string varName; ExpressionPtr fromExpr; @@ -359,7 +359,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; - AttributeValue unroll; + ExpressionValue unroll; std::optional varIndex; std::string varName; ExpressionPtr expression; @@ -388,12 +388,20 @@ namespace Nz::ShaderAst ExpressionPtr returnExpr; }; + struct NAZARA_SHADER_API ScopedStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + StatementPtr statement; + }; + struct NAZARA_SHADER_API WhileStatement : Statement { NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; - AttributeValue unroll; + ExpressionValue unroll; ExpressionPtr condition; StatementPtr body; }; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 4d878668b..caec552f6 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -19,6 +20,8 @@ namespace Nz::ShaderAst { class NAZARA_SHADER_API SanitizeVisitor final : AstCloner { + friend class AstTypeExpressionVisitor; + public: struct Options; @@ -55,6 +58,7 @@ namespace Nz::ShaderAst struct Identifier; using AstCloner::CloneExpression; + ExpressionValue CloneType(const ExpressionValue& exprType) override; ExpressionPtr Clone(AccessIdentifierExpression& node) override; ExpressionPtr Clone(AccessIndexExpression& node) override; @@ -84,50 +88,58 @@ namespace Nz::ShaderAst StatementPtr Clone(ForStatement& node) override; StatementPtr Clone(ForEachStatement& node) override; StatementPtr Clone(MultiStatement& node) override; + StatementPtr Clone(ScopedStatement& node) override; StatementPtr Clone(WhileStatement& node) override; const Identifier* FindIdentifier(const std::string_view& identifierName) const; + template const Identifier* FindIdentifier(const std::string_view& identifierName, F&& functor) const; + TypeParameter FindTypeParameter(const std::string_view& identifierName) const; - Expression& MandatoryExpr(const ExpressionPtr& node); - Statement& MandatoryStatement(const StatementPtr& node); - void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right); - void TypeMustMatch(const ExpressionType& left, const ExpressionType& right); + Expression& MandatoryExpr(const ExpressionPtr& node) const; + Statement& MandatoryStatement(const StatementPtr& node) const; void PushScope(); void PopScope(); ExpressionPtr CacheResult(ExpressionPtr expression); - template const T& ComputeAttributeValue(AttributeValue& attribute); - ConstantValue ComputeConstantValue(Expression& expr); - template std::unique_ptr Optimize(T& node); - - std::size_t DeclareFunction(DeclareFunctionStatement& funcDecl); + ConstantValue ComputeConstantValue(Expression& expr) const; + template const T& ComputeExprValue(ExpressionValue& attribute) const; + template std::unique_ptr Optimize(T& node) const; void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); + void RegisterBuiltin(); std::size_t RegisterConstant(std::string name, ConstantValue value); - FunctionData& RegisterFunction(std::size_t functionIndex); + std::size_t RegisterFunction(std::string name, FunctionData funcData); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterStruct(std::string name, StructDescription* description); + std::size_t RegisterType(std::string name, ExpressionType expressionType); + std::size_t RegisterType(std::string name, PartialType partialType); std::size_t RegisterVariable(std::string name, ExpressionType type); void ResolveFunctions(); - + const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const UniformType& uniformType); ExpressionType ResolveType(const ExpressionType& exprType); + ExpressionType ResolveType(const ExpressionValue& exprTypeValue); void SanitizeIdentifier(std::string& identifier); + void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const; + void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const; + + StatementPtr Unscope(StatementPtr node); + void Validate(WhileStatement& node); void Validate(AccessIndexExpression& node); void Validate(AssignExpression& node); void Validate(BinaryExpression& node); - void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration); + void Validate(CallFunctionExpression& node); void Validate(CastExpression& node); void Validate(DeclareVariableStatement& node); void Validate(IntrinsicExpression& node); @@ -141,7 +153,6 @@ namespace Nz::ShaderAst Bitset<> calledByFunctions; DeclareFunctionStatement* node; FunctionFlags flags; - bool defined = false; }; struct Identifier @@ -153,6 +164,7 @@ namespace Nz::ShaderAst Function, Intrinsic, Struct, + Type, Variable }; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 71f7ae2ce..744f9fd8c 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -19,7 +19,7 @@ namespace Nz { - class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept + class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstExpressionVisitorExcept, public ShaderAst::AstStatementVisitorExcept { public: using BindingMapping = std::unordered_map; @@ -51,15 +51,20 @@ namespace Nz private: void Append(const ShaderAst::ArrayType& type); - void Append(const ShaderAst::ExpressionType& type); void Append(ShaderAst::BuiltinEntry builtin); + void Append(const ShaderAst::ExpressionType& type); + void Append(const ShaderAst::ExpressionValue& type); + void Append(const ShaderAst::FunctionType& functionType); void Append(const ShaderAst::IdentifierType& identifierType); + void Append(const ShaderAst::IntrinsicFunctionType& intrinsicFunctionType); void Append(const ShaderAst::MatrixType& matrixType); + void Append(const ShaderAst::MethodType& methodType); void Append(ShaderAst::MemoryLayout layout); void Append(ShaderAst::NoType); void Append(ShaderAst::PrimitiveType type); void Append(const ShaderAst::SamplerType& samplerType); void Append(const ShaderAst::StructType& structType); + void Append(const ShaderAst::Type& type); void Append(const ShaderAst::UniformType& uniformType); void Append(const ShaderAst::VectorType& vecType); template void Append(const T& param); @@ -81,6 +86,8 @@ namespace Nz void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); void RegisterVariable(std::size_t varIndex, std::string varName); + void ScopeVisit(ShaderAst::Statement& node); + void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); void Visit(ShaderAst::AccessIdentifierExpression& node) override; @@ -107,6 +114,7 @@ namespace Nz void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; + void Visit(ShaderAst::ScopedStatement& node) override; void Visit(ShaderAst::WhileStatement& node) override; static bool HasExplicitBinding(ShaderAst::StatementPtr& shader); diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index 9c7f4123a..914ece90d 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -18,7 +18,7 @@ namespace Nz { - class NAZARA_SHADER_API LangWriter : public ShaderWriter, public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept + class NAZARA_SHADER_API LangWriter : public ShaderWriter, public ShaderAst::AstExpressionVisitorExcept, public ShaderAst::AstStatementVisitorExcept { public: struct Environment; @@ -49,12 +49,17 @@ namespace Nz void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ExpressionType& type); + void Append(const ShaderAst::ExpressionValue& type); + void Append(const ShaderAst::FunctionType& functionType); void Append(const ShaderAst::IdentifierType& identifierType); + void Append(const ShaderAst::IntrinsicFunctionType& intrinsicFunctionType); void Append(const ShaderAst::MatrixType& matrixType); + void Append(const ShaderAst::MethodType& methodType); void Append(ShaderAst::NoType); void Append(ShaderAst::PrimitiveType type); void Append(const ShaderAst::SamplerType& samplerType); void Append(const ShaderAst::StructType& structType); + void Append(const ShaderAst::Type& type); void Append(const ShaderAst::UniformType& uniformType); void Append(const ShaderAst::VectorType& vecType); template void Append(const T& param); @@ -84,6 +89,8 @@ namespace Nz void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); void RegisterVariable(std::size_t varIndex, std::string varName); + void ScopeVisit(ShaderAst::Statement& node); + void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); void Visit(ShaderAst::AccessIdentifierExpression& node) override; @@ -114,6 +121,7 @@ namespace Nz void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; + void Visit(ShaderAst::ScopedStatement& node) override; void Visit(ShaderAst::WhileStatement& node) override; struct State; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 5783e8b22..3ac22688b 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -50,13 +50,14 @@ namespace Nz::ShaderBuilder struct CallFunction { inline std::unique_ptr operator()(std::string functionName, std::vector parameters) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr functionExpr, std::vector parameters) const; }; struct Cast { - inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const; - inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::array expressions) const; - inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionValue targetType, ShaderAst::ExpressionPtr expression) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionValue targetType, std::array expressions) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionValue targetType, std::vector expressions) const; }; struct ConditionalExpression @@ -78,20 +79,20 @@ namespace Nz::ShaderBuilder struct DeclareConst { inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const; - inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; + inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionValue type, ShaderAst::ExpressionPtr initialValue = nullptr) const; }; struct DeclareFunction { inline std::unique_ptr operator()(std::string name, ShaderAst::StatementPtr statement) const; - inline std::unique_ptr operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const; + inline std::unique_ptr operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionValue returnType = ShaderAst::ExpressionType{ ShaderAst::NoType{} }) const; inline std::unique_ptr operator()(std::optional entryStage, std::string name, ShaderAst::StatementPtr statement) const; - inline std::unique_ptr operator()(std::optional entryStage, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const; + inline std::unique_ptr operator()(std::optional entryStage, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionValue returnType = ShaderAst::ExpressionType{ ShaderAst::NoType{} }) const; }; struct DeclareOption { - inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; + inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionValue type, ShaderAst::ExpressionPtr initialValue = nullptr) const; }; struct DeclareStruct @@ -102,7 +103,7 @@ namespace Nz::ShaderBuilder struct DeclareVariable { inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const; - inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; + inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionValue type, ShaderAst::ExpressionPtr initialValue = nullptr) const; }; struct ExpressionStatement @@ -147,6 +148,11 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr = nullptr) const; }; + struct Scoped + { + inline std::unique_ptr operator()(ShaderAst::StatementPtr statement) const; + }; + struct Swizzle { inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::array swizzleComponents, std::size_t componentCount) const; @@ -194,6 +200,7 @@ namespace Nz::ShaderBuilder constexpr Impl::Multi MultiStatement; constexpr Impl::NoParam NoOp; constexpr Impl::Return Return; + constexpr Impl::Scoped Scoped; constexpr Impl::Swizzle Swizzle; constexpr Impl::Unary Unary; constexpr Impl::Variable Variable; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 83a2423ab..fa3ca0457 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -105,13 +105,22 @@ namespace Nz::ShaderBuilder inline std::unique_ptr Impl::CallFunction::operator()(std::string functionName, std::vector parameters) const { auto callFunctionExpression = std::make_unique(); - callFunctionExpression->targetFunction = std::move(functionName); + callFunctionExpression->targetFunction = ShaderBuilder::Identifier(std::move(functionName)); callFunctionExpression->parameters = std::move(parameters); return callFunctionExpression; } - inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const + inline std::unique_ptr Impl::CallFunction::operator()(ShaderAst::ExpressionPtr functionExpr, std::vector parameters) const + { + auto callFunctionExpression = std::make_unique(); + callFunctionExpression->targetFunction = std::move(functionExpr); + callFunctionExpression->parameters = std::move(parameters); + + return callFunctionExpression; + } + + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionValue targetType, ShaderAst::ExpressionPtr expression) const { auto castNode = std::make_unique(); castNode->targetType = std::move(targetType); @@ -120,7 +129,7 @@ namespace Nz::ShaderBuilder return castNode; } - inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array expressions) const + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionValue targetType, std::array expressions) const { auto castNode = std::make_unique(); castNode->expressions = std::move(expressions); @@ -129,7 +138,7 @@ namespace Nz::ShaderBuilder return castNode; } - inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionValue targetType, std::vector expressions) const { auto castNode = std::make_unique(); castNode->targetType = std::move(targetType); @@ -194,7 +203,7 @@ namespace Nz::ShaderBuilder return declareConstNode; } - inline std::unique_ptr Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionValue type, ShaderAst::ExpressionPtr initialValue) const { auto declareConstNode = std::make_unique(); declareConstNode->name = std::move(name); @@ -213,7 +222,7 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } - inline std::unique_ptr Impl::DeclareFunction::operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType) const + inline std::unique_ptr Impl::DeclareFunction::operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionValue returnType) const { auto declareFunctionNode = std::make_unique(); declareFunctionNode->name = std::move(name); @@ -236,7 +245,7 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } - inline std::unique_ptr Impl::DeclareFunction::operator()(std::optional entryStage, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType) const + inline std::unique_ptr Impl::DeclareFunction::operator()(std::optional entryStage, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionValue returnType) const { auto declareFunctionNode = std::make_unique(); declareFunctionNode->name = std::move(name); @@ -250,7 +259,7 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } - inline std::unique_ptr Impl::DeclareOption::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareOption::operator()(std::string name, ShaderAst::ExpressionValue type, ShaderAst::ExpressionPtr initialValue) const { auto declareOptionNode = std::make_unique(); declareOptionNode->optName = std::move(name); @@ -277,7 +286,7 @@ namespace Nz::ShaderBuilder return declareVariableNode; } - inline std::unique_ptr Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionValue type, ShaderAst::ExpressionPtr initialValue) const { auto declareVariableNode = std::make_unique(); declareVariableNode->varName = std::move(name); @@ -367,6 +376,14 @@ namespace Nz::ShaderBuilder return returnNode; } + inline std::unique_ptr Impl::Scoped::operator()(ShaderAst::StatementPtr statement) const + { + auto scopedNode = std::make_unique(); + scopedNode->statement = std::move(statement); + + return scopedNode; + } + inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array swizzleComponents, std::size_t componentCount) const { assert(componentCount > 0); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index c69c505fa..7e4176cc4 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -70,36 +70,31 @@ namespace Nz::ShaderLang // Flow control const Token& Advance(); void Consume(std::size_t count = 1); - std::optional DecodeType(const std::string& identifier); - void EnterScope(); const Token& Expect(const Token& token, TokenType type); const Token& ExpectNot(const Token& token, TokenType type); const Token& Expect(TokenType type); - void LeaveScope(); - bool IsVariableInScope(const std::string_view& identifier) const; - void RegisterVariable(std::string identifier); const Token& Peek(std::size_t advance = 0); - std::vector ParseAttributes(); - void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue); + std::vector ParseAttributes(); + void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue& type, ShaderAst::ExpressionPtr& initialValue); // Statements ShaderAst::StatementPtr ParseBranchStatement(); ShaderAst::StatementPtr ParseConstStatement(); ShaderAst::StatementPtr ParseDiscardStatement(); - ShaderAst::StatementPtr ParseExternalBlock(std::vector attributes = {}); - ShaderAst::StatementPtr ParseForDeclaration(std::vector attributes = {}); + ShaderAst::StatementPtr ParseExternalBlock(std::vector attributes = {}); + ShaderAst::StatementPtr ParseForDeclaration(std::vector attributes = {}); std::vector ParseFunctionBody(); - ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); + ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter(); ShaderAst::StatementPtr ParseOptionDeclaration(); ShaderAst::StatementPtr ParseReturnStatement(); ShaderAst::StatementPtr ParseSingleStatement(); ShaderAst::StatementPtr ParseStatement(); std::vector ParseStatementList(); - ShaderAst::StatementPtr ParseStructDeclaration(std::vector attributes = {}); + ShaderAst::StatementPtr ParseStructDeclaration(std::vector attributes = {}); ShaderAst::StatementPtr ParseVariableDeclaration(); - ShaderAst::StatementPtr ParseWhileStatement(std::vector attributes); + ShaderAst::StatementPtr ParseWhileStatement(std::vector attributes); // Expressions ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); @@ -115,8 +110,7 @@ namespace Nz::ShaderLang ShaderAst::AttributeType ParseIdentifierAsAttributeType(); const std::string& ParseIdentifierAsName(); - ShaderAst::PrimitiveType ParsePrimitiveType(); - ShaderAst::ExpressionType ParseType(); + ShaderAst::ExpressionPtr ParseType(); static int GetTokenPrecedence(TokenType token); @@ -124,8 +118,6 @@ namespace Nz::ShaderLang { std::size_t tokenCount; std::size_t tokenIndex = 0; - std::vector scopeSizes; - std::vector identifiersInScope; std::unique_ptr root; const Token* tokens; }; diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 28e98b157..7f71c0a51 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -20,7 +20,7 @@ namespace Nz { class SpirvWriter; - class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept + class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::AstExpressionVisitorExcept, public ShaderAst::AstStatementVisitorExcept { public: struct EntryPoint; @@ -38,8 +38,8 @@ namespace Nz const Variable& GetVariable(std::size_t varIndex) const; - using ExpressionVisitorExcept::Visit; - using StatementVisitorExcept::Visit; + using AstExpressionVisitorExcept::Visit; + using AstStatementVisitorExcept::Visit; void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; @@ -60,6 +60,7 @@ namespace Nz void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::ReturnStatement& node) override; + void Visit(ShaderAst::ScopedStatement& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::VariableExpression& node) override; diff --git a/include/Nazara/Shader/SpirvExpressionLoad.hpp b/include/Nazara/Shader/SpirvExpressionLoad.hpp index 253ad1776..2ff279f98 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.hpp +++ b/include/Nazara/Shader/SpirvExpressionLoad.hpp @@ -19,7 +19,7 @@ namespace Nz class SpirvBlock; class SpirvWriter; - class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAst::ExpressionVisitorExcept + class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAst::AstExpressionVisitorExcept { public: inline SpirvExpressionLoad(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block); @@ -29,7 +29,7 @@ namespace Nz UInt32 Evaluate(ShaderAst::Expression& node); - using ExpressionVisitorExcept::Visit; + using AstExpressionVisitorExcept::Visit; void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::VariableExpression& node) override; @@ -37,6 +37,15 @@ namespace Nz SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete; private: + struct PointerChainAccess + { + std::vector indices; + const ShaderAst::ExpressionType* exprType; + SpirvStorageClass storage; + UInt32 pointerId; + UInt32 pointedTypeId; + }; + struct Pointer { SpirvStorageClass storage; @@ -46,13 +55,20 @@ namespace Nz struct Value { - UInt32 resultId; + UInt32 valueId; + }; + + struct ValueExtraction + { + std::vector indices; + UInt32 typeId; + UInt32 valueId; }; SpirvAstVisitor& m_visitor; SpirvBlock& m_block; SpirvWriter& m_writer; - std::variant m_value; + std::variant m_value; }; } diff --git a/include/Nazara/Shader/SpirvExpressionStore.hpp b/include/Nazara/Shader/SpirvExpressionStore.hpp index f85fa524f..43a58fcb1 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.hpp +++ b/include/Nazara/Shader/SpirvExpressionStore.hpp @@ -19,7 +19,7 @@ namespace Nz class SpirvBlock; class SpirvWriter; - class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAst::ExpressionVisitorExcept + class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAst::AstExpressionVisitorExcept { public: inline SpirvExpressionStore(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block); @@ -29,7 +29,7 @@ namespace Nz void Store(ShaderAst::ExpressionPtr& node, UInt32 resultId); - using ExpressionVisitorExcept::Visit; + using AstExpressionVisitorExcept::Visit; void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::VariableExpression& node) override; diff --git a/src/Nazara/Graphics/BasicMaterial.cpp b/src/Nazara/Graphics/BasicMaterial.cpp index 870d32833..b02e922ff 100644 --- a/src/Nazara/Graphics/BasicMaterial.cpp +++ b/src/Nazara/Graphics/BasicMaterial.cpp @@ -184,8 +184,8 @@ namespace Nz if (!uberShader->HasOption(optionName, &optionPtr)) return InvalidOption; - if (optionPtr->type != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::Int32 }) - throw std::runtime_error("Location options must be of type i32"); + //if (optionPtr->type != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::Int32 }) + // throw std::runtime_error("Location options must be of type i32"); return optionPtr->index; }; diff --git a/src/Nazara/Graphics/PhongLightingMaterial.cpp b/src/Nazara/Graphics/PhongLightingMaterial.cpp index 0c6cb5a61..4568d6fc3 100644 --- a/src/Nazara/Graphics/PhongLightingMaterial.cpp +++ b/src/Nazara/Graphics/PhongLightingMaterial.cpp @@ -251,8 +251,8 @@ namespace Nz if (!uberShader->HasOption(optionName, &optionPtr)) return InvalidOption; - if (optionPtr->type != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::Int32 }) - throw std::runtime_error("Location options must be of type i32"); + //if (optionPtr->type != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::Int32 }) + // throw std::runtime_error("Location options must be of type i32"); return optionPtr->index; }; diff --git a/src/Nazara/Graphics/Resources/Shaders/basic_material.nzsl b/src/Nazara/Graphics/Resources/Shaders/basic_material.nzsl index a26d2358a..5277c74ff 100644 --- a/src/Nazara/Graphics/Resources/Shaders/basic_material.nzsl +++ b/src/Nazara/Graphics/Resources/Shaders/basic_material.nzsl @@ -10,7 +10,7 @@ option BillboardSizeRotLocation: i32 = -1; // Vertex declaration related options option ColorLocation: i32 = -1; -option PosLocation: i32 = -1; +option PosLocation: i32; option UvLocation: i32 = -1; const HasVertexColor = (ColorLocation >= 0); diff --git a/src/Nazara/Graphics/Resources/Shaders/phong_material.nzsl b/src/Nazara/Graphics/Resources/Shaders/phong_material.nzsl index 984088ed4..446ce857d 100644 --- a/src/Nazara/Graphics/Resources/Shaders/phong_material.nzsl +++ b/src/Nazara/Graphics/Resources/Shaders/phong_material.nzsl @@ -20,7 +20,7 @@ option BillboardSizeRotLocation: i32 = -1; // Vertex declaration related options option ColorLocation: i32 = -1; option NormalLocation: i32 = -1; -option PosLocation: i32 = -1; +option PosLocation: i32; option TangentLocation: i32 = -1; option UvLocation: i32 = -1; diff --git a/src/Nazara/Graphics/UberShader.cpp b/src/Nazara/Graphics/UberShader.cpp index bfd42d196..84a33e426 100644 --- a/src/Nazara/Graphics/UberShader.cpp +++ b/src/Nazara/Graphics/UberShader.cpp @@ -31,11 +31,10 @@ namespace Nz supportedStageType |= stageType; }; - callbacks.onOptionDeclaration = [&](const std::string& optionName, const ShaderAst::ExpressionType& optionType) + callbacks.onOptionDeclaration = [&](const std::string& optionName, const ShaderAst::ExpressionValue& optionType) { m_optionIndexByName[optionName] = Option{ - optionCount, - optionType + optionCount }; optionCount++; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 9b4ce5b30..51be1d4c3 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -36,6 +36,20 @@ namespace Nz::ShaderAst return PopStatement(); } + ExpressionValue AstCloner::CloneType(const ExpressionValue& exprType) + { + if (!exprType.HasValue()) + return {}; + + if (exprType.IsExpression()) + return CloneExpression(exprType.GetExpression()); + else + { + assert(exprType.IsResultingValue()); + return exprType.GetResultingValue(); + } + } + StatementPtr AstCloner::Clone(BranchStatement& node) { auto clone = std::make_unique(); @@ -68,7 +82,7 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->constIndex = node.constIndex; clone->name = node.name; - clone->type = node.type; + clone->type = Clone(node.type); clone->expression = CloneExpression(node.expression); return clone; @@ -86,7 +100,7 @@ namespace Nz::ShaderAst { auto& cloneVar = clone->externalVars.emplace_back(); cloneVar.name = var.name; - cloneVar.type = var.type; + cloneVar.type = Clone(var.type); cloneVar.bindingIndex = Clone(var.bindingIndex); cloneVar.bindingSet = Clone(var.bindingSet); } @@ -102,10 +116,17 @@ namespace Nz::ShaderAst clone->entryStage = Clone(node.entryStage); clone->funcIndex = node.funcIndex; clone->name = node.name; - clone->parameters = node.parameters; - clone->returnType = node.returnType; + clone->returnType = Clone(node.returnType); clone->varIndex = node.varIndex; + clone->parameters.reserve(node.parameters.size()); + for (auto& parameter : node.parameters) + { + auto& cloneParam = clone->parameters.emplace_back(); + cloneParam.name = parameter.name; + cloneParam.type = Clone(parameter.type); + } + clone->statements.reserve(node.statements.size()); for (auto& statement : node.statements) clone->statements.push_back(CloneStatement(statement)); @@ -119,7 +140,7 @@ namespace Nz::ShaderAst clone->defaultValue = CloneExpression(node.defaultValue); clone->optIndex = node.optIndex; clone->optName = node.optName; - clone->optType = node.optType; + clone->optType = Clone(node.optType); return clone; } @@ -137,7 +158,7 @@ namespace Nz::ShaderAst { auto& cloneMember = clone->description.members.emplace_back(); cloneMember.name = member.name; - cloneMember.type = member.type; + cloneMember.type = Clone(member.type); cloneMember.builtin = Clone(member.builtin); cloneMember.cond = Clone(member.cond); cloneMember.locationIndex = Clone(member.locationIndex); @@ -151,7 +172,7 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->varIndex = node.varIndex; clone->varName = node.varName; - clone->varType = node.varType; + clone->varType = Clone(node.varType); clone->initialExpression = CloneExpression(node.initialExpression); return clone; @@ -217,6 +238,14 @@ namespace Nz::ShaderAst return clone; } + StatementPtr AstCloner::Clone(ScopedStatement& node) + { + auto clone = std::make_unique(); + clone->statement = CloneStatement(node.statement); + + return clone; + } + StatementPtr AstCloner::Clone(WhileStatement& node) { auto clone = std::make_unique(); @@ -279,7 +308,7 @@ namespace Nz::ShaderAst ExpressionPtr AstCloner::Clone(CallFunctionExpression& node) { auto clone = std::make_unique(); - clone->targetFunction = node.targetFunction; + clone->targetFunction = CloneExpression(node.targetFunction); clone->parameters.reserve(node.parameters.size()); for (auto& parameter : node.parameters) @@ -309,7 +338,7 @@ namespace Nz::ShaderAst ExpressionPtr AstCloner::Clone(CastExpression& node) { auto clone = std::make_unique(); - clone->targetType = node.targetType; + clone->targetType = Clone(node.targetType); std::size_t expressionCount = 0; for (auto& expr : node.expressions) diff --git a/src/Nazara/Shader/Ast/AstExpressionVisitorExcept.cpp b/src/Nazara/Shader/Ast/AstExpressionVisitorExcept.cpp index 5aabaece2..ebf24bce6 100644 --- a/src/Nazara/Shader/Ast/AstExpressionVisitorExcept.cpp +++ b/src/Nazara/Shader/Ast/AstExpressionVisitorExcept.cpp @@ -7,7 +7,7 @@ namespace Nz::ShaderAst { -#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \ +#define NAZARA_SHADERAST_EXPRESSION(Node) void AstExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \ { \ throw std::runtime_error("unexpected " #Node " node"); \ } diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index 9d3532381..b6bf8bff7 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -818,13 +818,13 @@ namespace Nz::ShaderAst } ExpressionPtr optimized; - if (IsPrimitiveType(node.targetType)) + if (IsPrimitiveType(node.targetType.GetResultingValue())) { if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression) { const ConstantValueExpression& constantExpr = static_cast(*expressions.front()); - switch (std::get(node.targetType)) + switch (std::get(node.targetType.GetResultingValue())) { case PrimitiveType::Boolean: optimized = PropagateSingleValueCast(constantExpr); break; case PrimitiveType::Float32: optimized = PropagateSingleValueCast(constantExpr); break; @@ -833,9 +833,9 @@ namespace Nz::ShaderAst } } } - else if (IsVectorType(node.targetType)) + else if (IsVectorType(node.targetType.GetResultingValue())) { - const auto& vecType = std::get(node.targetType); + const auto& vecType = std::get(node.targetType.GetResultingValue()); // Decompose vector into values (cast(vec3, float) => cast(float, float, float, float)) std::vector constantValues; @@ -916,7 +916,7 @@ namespace Nz::ShaderAst if (optimized) return optimized; - auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions)); + auto cast = ShaderBuilder::Cast(node.targetType.GetResultingValue(), std::move(expressions)); cast->cachedExpressionType = node.cachedExpressionType; return cast; @@ -946,7 +946,7 @@ namespace Nz::ShaderAst if (statements.empty()) { // First condition is true, dismiss the branch - return AstCloner::Clone(*condStatement.statement); + return Unscope(AstCloner::Clone(*condStatement.statement)); } else { @@ -967,7 +967,7 @@ namespace Nz::ShaderAst { // All conditions have been removed, replace by else statement or no-op if (node.elseStatement) - return AstCloner::Clone(*node.elseStatement); + return Unscope(AstCloner::Clone(*node.elseStatement)); else return ShaderBuilder::NoOp(); } @@ -1243,4 +1243,15 @@ namespace Nz::ShaderAst return optimized; } + + + StatementPtr AstOptimizer::Unscope(StatementPtr node) + { + assert(node); + + if (node->GetType() == NodeType::ScopedStatement) + return std::move(static_cast(*node).statement); + else + return node; + } } diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 061ef2d52..73718cf6b 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -202,6 +202,12 @@ namespace Nz::ShaderAst node.returnExpr->Visit(*this); } + void AstRecursiveVisitor::Visit(ScopedStatement& node) + { + if (node.statement) + node.statement->Visit(*this); + } + void AstRecursiveVisitor::Visit(WhileStatement& node) { if (node.condition) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index e5a453f75..e81dd04c5 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -67,27 +67,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(CallFunctionExpression& node) { - UInt32 typeIndex; - if (IsWriting()) - typeIndex = UInt32(node.targetFunction.index()); - - Value(typeIndex); - - // Waiting for template lambda in C++20 - auto SerializeValue = [&](auto dummyType) - { - using T = std::decay_t; - - auto& value = (IsWriting()) ? std::get(node.targetFunction) : node.targetFunction.emplace(); - Value(value); - }; - - static_assert(std::variant_size_v == 2); - switch (typeIndex) - { - case 0: SerializeValue(std::string()); break; - case 1: SerializeValue(std::size_t()); break; - } + Node(node.targetFunction); Container(node.parameters); for (auto& param : node.parameters) @@ -106,7 +86,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(CastExpression& node) { - Type(node.targetType); + ExprValue(node.targetType); for (auto& expr : node.expressions) Node(expr); } @@ -215,15 +195,15 @@ namespace Nz::ShaderAst { OptVal(node.varIndex); - Attribute(node.bindingSet); + ExprValue(node.bindingSet); Container(node.externalVars); for (auto& extVar : node.externalVars) { Value(extVar.name); - Type(extVar.type); - Attribute(extVar.bindingIndex); - Attribute(extVar.bindingSet); + ExprValue(extVar.type); + ExprValue(extVar.bindingIndex); + ExprValue(extVar.bindingSet); } } @@ -231,17 +211,17 @@ namespace Nz::ShaderAst { OptVal(node.constIndex); Value(node.name); - Type(node.type); + ExprValue(node.type); Node(node.expression); } void AstSerializerBase::Serialize(DeclareFunctionStatement& node) { Value(node.name); - Type(node.returnType); - Attribute(node.depthWrite); - Attribute(node.earlyFragmentTests); - Attribute(node.entryStage); + ExprValue(node.returnType); + ExprValue(node.depthWrite); + ExprValue(node.earlyFragmentTests); + ExprValue(node.entryStage); OptVal(node.funcIndex); OptVal(node.varIndex); @@ -249,7 +229,7 @@ namespace Nz::ShaderAst for (auto& parameter : node.parameters) { Value(parameter.name); - Type(parameter.type); + ExprValue(parameter.type); } Container(node.statements); @@ -261,7 +241,7 @@ namespace Nz::ShaderAst { OptVal(node.optIndex); Value(node.optName); - Type(node.optType); + ExprValue(node.optType); Node(node.defaultValue); } @@ -270,16 +250,16 @@ namespace Nz::ShaderAst OptVal(node.structIndex); Value(node.description.name); - Attribute(node.description.layout); + ExprValue(node.description.layout); Container(node.description.members); for (auto& member : node.description.members) { Value(member.name); - Type(member.type); - Attribute(member.builtin); - Attribute(member.cond); - Attribute(member.locationIndex); + ExprValue(member.type); + ExprValue(member.builtin); + ExprValue(member.cond); + ExprValue(member.locationIndex); } } @@ -287,7 +267,7 @@ namespace Nz::ShaderAst { OptVal(node.varIndex); Value(node.varName); - Type(node.varType); + ExprValue(node.varType); Node(node.initialExpression); } @@ -303,7 +283,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(ForStatement& node) { - Attribute(node.unroll); + ExprValue(node.unroll); Value(node.varName); Node(node.fromExpr); Node(node.toExpr); @@ -313,7 +293,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(ForEachStatement& node) { - Attribute(node.unroll); + ExprValue(node.unroll); Value(node.varName); Node(node.expression); Node(node.statement); @@ -336,9 +316,14 @@ namespace Nz::ShaderAst Node(node.returnExpr); } + void AstSerializerBase::Serialize(ScopedStatement& node) + { + Node(node.statement); + } + void AstSerializerBase::Serialize(WhileStatement& node) { - Attribute(node.unroll); + ExprValue(node.unroll); Node(node.condition); Node(node.body); } @@ -392,7 +377,7 @@ namespace Nz::ShaderAst else if constexpr (std::is_same_v) { m_stream << UInt8(1); - m_stream << UInt32(arg); + Enum(arg); } else if constexpr (std::is_same_v) { @@ -402,38 +387,59 @@ namespace Nz::ShaderAst else if constexpr (std::is_same_v) { m_stream << UInt8(3); - m_stream << UInt32(arg.columnCount); - m_stream << UInt32(arg.rowCount); - m_stream << UInt32(arg.type); + SizeT(arg.columnCount); + SizeT(arg.rowCount); + Enum(arg.type); } else if constexpr (std::is_same_v) { m_stream << UInt8(4); - m_stream << UInt32(arg.dim); - m_stream << UInt32(arg.sampledType); + Enum(arg.dim); + Enum(arg.sampledType); } else if constexpr (std::is_same_v) { m_stream << UInt8(5); - m_stream << UInt32(arg.structIndex); + SizeT(arg.structIndex); } else if constexpr (std::is_same_v) { m_stream << UInt8(6); - m_stream << std::get(arg.containedType).name; + SizeT(arg.containedType.structIndex); } else if constexpr (std::is_same_v) { m_stream << UInt8(7); - m_stream << UInt32(arg.componentCount); - m_stream << UInt32(arg.type); + SizeT(arg.componentCount); + Enum(arg.type); } else if constexpr (std::is_same_v) { m_stream << UInt8(8); - Attribute(arg.length); + Value(arg.length); Type(arg.containedType->type); } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(9); + SizeT(arg.typeIndex); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(10); + SizeT(arg.funcIndex); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(11); + Enum(arg.intrinsic); + } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(12); + Type(arg.objectType->type); + SizeT(arg.methodIndex); + } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, type); @@ -618,10 +624,10 @@ namespace Nz::ShaderAst case 3: //< MatrixType { - UInt32 columnCount, rowCount; + std::size_t columnCount, rowCount; PrimitiveType primitiveType; - Value(columnCount); - Value(rowCount); + SizeT(columnCount); + SizeT(rowCount); Enum(primitiveType); type = MatrixType { @@ -659,12 +665,12 @@ namespace Nz::ShaderAst case 6: //< UniformType { - std::string containedType; - Value(containedType); + std::size_t structIndex; + SizeT(structIndex); type = UniformType { - IdentifierType { - containedType + StructType { + structIndex } }; break; @@ -672,9 +678,9 @@ namespace Nz::ShaderAst case 7: //< VectorType { - UInt32 componentCount; + std::size_t componentCount; PrimitiveType componentType; - Value(componentCount); + SizeT(componentCount); Enum(componentType); type = VectorType{ @@ -686,13 +692,13 @@ namespace Nz::ShaderAst case 8: //< ArrayType { - AttributeValue length; + UInt32 length; ExpressionType containedType; - Attribute(length); + Value(length); Type(containedType); ArrayType arrayType; - arrayType.length = std::move(length); + arrayType.length = length; arrayType.containedType = std::make_unique(); arrayType.containedType->type = std::move(containedType); @@ -700,6 +706,52 @@ namespace Nz::ShaderAst break; } + case 9: //< Type + { + std::size_t containedTypeIndex; + SizeT(containedTypeIndex); + + type = ShaderAst::Type{ + containedTypeIndex + }; + } + + case 10: //< FunctionType + { + std::size_t funcIndex; + SizeT(funcIndex); + + type = FunctionType { + funcIndex + }; + } + + case 11: //< IntrinsicFunctionType + { + IntrinsicType intrinsicType; + Enum(intrinsicType); + + type = IntrinsicFunctionType { + intrinsicType + }; + } + + case 12: //< MethodType + { + ExpressionType objectType; + Type(objectType); + + std::size_t methodIndex; + SizeT(methodIndex); + + MethodType methodType; + methodType.objectType = std::make_unique(); + methodType.objectType->type = std::move(objectType); + methodType.methodIndex = methodIndex; + + type = std::move(methodType); + } + default: break; } diff --git a/src/Nazara/Shader/Ast/AstStatementVisitorExcept.cpp b/src/Nazara/Shader/Ast/AstStatementVisitorExcept.cpp index e0c658117..9b46f5bed 100644 --- a/src/Nazara/Shader/Ast/AstStatementVisitorExcept.cpp +++ b/src/Nazara/Shader/Ast/AstStatementVisitorExcept.cpp @@ -7,7 +7,7 @@ namespace Nz::ShaderAst { -#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \ +#define NAZARA_SHADERAST_STATEMENT(Node) void AstStatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \ { \ throw std::runtime_error("unexpected " #Node " node"); \ } diff --git a/src/Nazara/Shader/Ast/ExpressionType.cpp b/src/Nazara/Shader/Ast/ExpressionType.cpp index 186bbc20b..7bc68b492 100644 --- a/src/Nazara/Shader/Ast/ExpressionType.cpp +++ b/src/Nazara/Shader/Ast/ExpressionType.cpp @@ -9,11 +9,11 @@ namespace Nz::ShaderAst { - ArrayType::ArrayType(const ArrayType& array) + ArrayType::ArrayType(const ArrayType& array) : + length(array.length) { assert(array.containedType); containedType = std::make_unique(*array.containedType); - length = Clone(array.length); } ArrayType& ArrayType::operator=(const ArrayType& array) @@ -21,7 +21,7 @@ namespace Nz::ShaderAst assert(array.containedType); containedType = std::make_unique(*array.containedType); - length = Clone(array.length); + length = array.length; return *this; } @@ -34,9 +34,34 @@ namespace Nz::ShaderAst if (containedType->type != rhs.containedType->type) return false; - if (!Compare(length, rhs.length)) + if (length != rhs.length) return false; return true; } + + + MethodType::MethodType(const MethodType& methodType) : + methodIndex(methodType.methodIndex) + { + assert(methodType.objectType); + objectType = std::make_unique(*methodType.objectType); + } + + MethodType& MethodType::operator=(const MethodType& methodType) + { + assert(methodType.objectType); + + methodIndex = methodType.methodIndex; + objectType = std::make_unique(*methodType.objectType); + + return *this; + } + + bool MethodType::operator==(const MethodType& rhs) const + { + assert(objectType); + assert(rhs.objectType); + return objectType->type == rhs.objectType->type && methodIndex == rhs.methodIndex; + } } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index c4dd366f7..741f1c784 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -41,18 +42,31 @@ namespace Nz::ShaderAst FunctionFlags flags; }; + struct PendingFunction + { + DeclareFunctionStatement* cloneNode; + const DeclareFunctionStatement* node; + }; + + struct Scope + { + std::size_t previousSize; + }; + std::size_t nextOptionIndex = 0; Options options; std::array entryFunctions = {}; std::unordered_set declaredExternalVar; std::unordered_set usedBindingIndexes; - std::vector identifiersInScope; std::vector constantValues; std::vector functions; + std::vector identifiersInScope; std::vector intrinsics; + std::vector pendingFunctions; std::vector structs; + std::vector> types; std::vector variableTypes; - std::vector scopeSizes; + std::vector scopes; CurrentFunctionData* currentFunction = nullptr; std::vector* currentStatementList = nullptr; }; @@ -69,41 +83,9 @@ namespace Nz::ShaderAst PushScope(); //< Global scope { - RegisterIntrinsic("cross", IntrinsicType::CrossProduct); - RegisterIntrinsic("dot", IntrinsicType::DotProduct); - RegisterIntrinsic("exp", IntrinsicType::Exp); - RegisterIntrinsic("length", IntrinsicType::Length); - RegisterIntrinsic("max", IntrinsicType::Max); - RegisterIntrinsic("min", IntrinsicType::Min); - RegisterIntrinsic("normalize", IntrinsicType::Normalize); - RegisterIntrinsic("pow", IntrinsicType::Pow); - RegisterIntrinsic("reflect", IntrinsicType::Reflect); - - // Collect function name and their types - if (statement.GetType() == NodeType::MultiStatement) - { - const MultiStatement& multiStatement = static_cast(statement); - for (auto& statementPtr : multiStatement.statements) - { - if (statementPtr->GetType() == NodeType::DeclareFunctionStatement) - DeclareFunction(static_cast(*statementPtr)); - else if (statementPtr->GetType() == NodeType::ConditionalStatement) - { - const ConditionalStatement& condStatement = static_cast(*statementPtr); - if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement) - DeclareFunction(static_cast(*condStatement.statement)); - } - } - } - else if (statement.GetType() == NodeType::DeclareFunctionStatement) - DeclareFunction(static_cast(statement)); - else if (statement.GetType() == NodeType::ConditionalStatement) - { - const ConditionalStatement& condStatement = static_cast(statement); - if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement) - DeclareFunction(static_cast(*condStatement.statement)); - } + RegisterBuiltin(); + // First pass, evaluate everything except function code try { clone = AstCloner::Clone(statement); @@ -115,6 +97,13 @@ namespace Nz::ShaderAst *error = err.errMsg; } + catch (const std::runtime_error& err) + { + if (!error) + throw; + + *error = err.what(); + } ResolveFunctions(); } @@ -152,6 +141,14 @@ namespace Nz::ShaderAst } } + ExpressionValue SanitizeVisitor::CloneType(const ExpressionValue& exprType) + { + if (!exprType.HasValue()) + return {}; + + return ResolveType(exprType); + } + ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node) { if (node.identifiers.empty()) @@ -164,7 +161,28 @@ namespace Nz::ShaderAst throw AstError{ "empty identifier" }; const ExpressionType& exprType = GetExpressionType(*indexedExpr); - if (IsStructType(exprType)) + // TODO: Add proper support for methods + if (IsSamplerType(exprType)) + { + if (identifier == "Sample") + { + // TODO: Add a MethodExpression? + auto identifierExpr = std::make_unique(); + identifierExpr->expr = std::move(indexedExpr); + identifierExpr->identifiers.push_back(identifier); + + MethodType methodType; + methodType.methodIndex = 0; //< FIXME + methodType.objectType = std::make_unique(); + methodType.objectType->type = exprType; + + identifierExpr->cachedExpressionType = std::move(methodType); + indexedExpr = std::move(identifierExpr); + } + else + throw AstError{ "type has no method " + identifier }; + } + else if (IsStructType(exprType)) { std::size_t structIndex = ResolveStruct(exprType); assert(structIndex < m_context->structs.size()); @@ -211,20 +229,12 @@ namespace Nz::ShaderAst else { // Transform to AccessIndexExpression - AccessIndexExpression* accessIndexPtr; - if (indexedExpr->GetType() != NodeType::AccessIndexExpression) - { - std::unique_ptr accessIndex = std::make_unique(); - accessIndex->expr = std::move(indexedExpr); + std::unique_ptr accessIndex = std::make_unique(); + accessIndex->expr = std::move(indexedExpr); + accessIndex->indices.push_back(ShaderBuilder::Constant(fieldIndex)); + accessIndex->cachedExpressionType = ResolveType(fieldPtr->type); - accessIndexPtr = accessIndex.get(); - indexedExpr = std::move(accessIndex); - } - else - accessIndexPtr = static_cast(indexedExpr.get()); - - accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex)); - accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type); + indexedExpr = std::move(accessIndex); } } else if (IsPrimitiveType(exprType) || IsVectorType(exprType)) @@ -255,7 +265,7 @@ namespace Nz::ShaderAst baseType = std::get(exprType); auto cast = std::make_unique(); - cast->targetType = VectorType{ swizzleComponentCount, baseType }; + cast->targetType = ExpressionType{ VectorType{ swizzleComponentCount, baseType } }; for (std::size_t j = 0; j < swizzleComponentCount; ++j) cast->expressions[j] = CloneExpression(indexedExpr); @@ -319,67 +329,78 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) { - if (!m_context->currentFunction) - throw AstError{ "function calls must happen inside a function" }; + ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction)); + const ExpressionType& targetExprType = GetExpressionType(*targetExpr); - auto clone = std::make_unique(); - - clone->parameters.reserve(node.parameters.size()); - for (const auto& parameter : node.parameters) - clone->parameters.push_back(CloneExpression(parameter)); - - std::size_t targetFuncIndex; - if (std::holds_alternative(node.targetFunction)) + if (IsFunctionType(targetExprType)) { - const std::string& functionName = std::get(node.targetFunction); + if (!m_context->currentFunction) + throw AstError{ "function calls must happen inside a function" }; - const Identifier* identifier = FindIdentifier(functionName); - if (identifier) - { - if (identifier->type == Identifier::Type::Intrinsic) - { - // Intrinsic function call - std::vector parameters; - parameters.reserve(node.parameters.size()); + std::size_t targetFuncIndex = std::get(targetExprType).funcIndex; - for (const auto& param : node.parameters) - parameters.push_back(CloneExpression(param)); + auto clone = std::make_unique(); + clone->targetFunction = std::move(targetExpr); - auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics[identifier->index], std::move(parameters)); - Validate(*intrinsic); + clone->parameters.reserve(node.parameters.size()); + for (const auto& parameter : node.parameters) + clone->parameters.push_back(CloneExpression(parameter)); - return intrinsic; - } - else - { - // Regular function call - if (identifier->type != Identifier::Type::Function) - throw AstError{ "function expected" }; + m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex); - clone->targetFunction = identifier->index; - targetFuncIndex = identifier->index; - } - } - else - { - // Identifier not found, maybe the function is declared later - auto it = std::find_if(m_context->functions.begin(), m_context->functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; }); - if (it == m_context->functions.end()) - throw AstError{ "function " + functionName + " does not exist" }; + Validate(*clone); - targetFuncIndex = std::distance(m_context->functions.begin(), it); + return clone; + } + else if (IsIntrinsicFunctionType(targetExprType)) + { + std::vector parameters; + parameters.reserve(node.parameters.size()); - clone->targetFunction = targetFuncIndex; - } + for (const auto& param : node.parameters) + parameters.push_back(CloneExpression(param)); + + auto intrinsic = ShaderBuilder::Intrinsic(std::get(targetExprType).intrinsic, std::move(parameters)); + Validate(*intrinsic); + + return intrinsic; + } + else if (IsMethodType(targetExprType)) + { + const MethodType& methodType = std::get(targetExprType); + + std::vector parameters; + parameters.reserve(node.parameters.size() + 1); + + // TODO: Add MethodExpression + assert(targetExpr->GetType() == NodeType::AccessIdentifierExpression); + + parameters.push_back(std::move(static_cast(*targetExpr).expr)); + for (const auto& param : node.parameters) + parameters.push_back(CloneExpression(param)); + + assert(IsSamplerType(methodType.objectType->type) && methodType.methodIndex == 0); + auto intrinsic = ShaderBuilder::Intrinsic(IntrinsicType::SampleTexture, std::move(parameters)); + Validate(*intrinsic); + + return intrinsic; } else - targetFuncIndex = std::get(node.targetFunction); + { + // Calling a type - vec3[f32](0.0, 1.0, 2.0) - it's a cast + auto clone = std::make_unique(); + clone->targetType = std::move(targetExprType); - m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex); + if (node.parameters.size() > clone->expressions.size()) + throw AstError{ "component count doesn't match required component count" }; - Validate(*clone, m_context->functions[targetFuncIndex].node); + for (std::size_t i = 0; i < node.parameters.size(); ++i) + clone->expressions[i] = CloneExpression(node.parameters[i]); - return clone; + Validate(*clone); + + return Clone(*clone); //< Necessary because cast has to be modified (FIXME) + } } ExpressionPtr SanitizeVisitor::Clone(CastExpression& node) @@ -387,9 +408,11 @@ namespace Nz::ShaderAst auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); Validate(*clone); - if (m_context->options.removeMatrixCast && IsMatrixType(clone->targetType)) + const ExpressionType& targetType = clone->targetType.GetResultingValue(); + + if (m_context->options.removeMatrixCast && IsMatrixType(targetType)) { - const MatrixType& targetMatrixType = std::get(clone->targetType); + const MatrixType& targetMatrixType = std::get(targetType); const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front()); bool isMatrixCast = IsMatrixType(frontExprType); @@ -399,7 +422,7 @@ namespace Nz::ShaderAst return std::move(clone->expressions.front()); } - auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", clone->targetType); //< Validation will prevent name-clash if required + auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", targetType); //< Validation will prevent name-clash if required Validate(*variableDeclaration); std::size_t variableIndex = *variableDeclaration->varIndex; @@ -409,7 +432,7 @@ namespace Nz::ShaderAst for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) { // temp[i] - auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, clone->targetType), ShaderBuilder::Constant(UInt32(i))); + auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, targetType), ShaderBuilder::Constant(UInt32(i))); Validate(*columnExpr); // vector expression @@ -441,9 +464,9 @@ namespace Nz::ShaderAst std::array expressions; expressions[0] = std::move(vectorExpr); for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j) - expressions[j + 1] = ShaderBuilder::Constant(targetMatrixType.type, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal + expressions[j + 1] = ShaderBuilder::Constant(ExpressionType{ targetMatrixType.type }, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal - vecCast = ShaderBuilder::Cast(VectorType{ targetMatrixType.rowCount, targetMatrixType.type }, std::move(expressions)); + vecCast = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); Validate(*vecCast); castExpr = std::move(vecCast); @@ -466,7 +489,7 @@ namespace Nz::ShaderAst m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr)))); } - return ShaderBuilder::Variable(variableIndex, clone->targetType); + return ShaderBuilder::Variable(variableIndex, targetType); } return clone; @@ -474,18 +497,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) { - MandatoryExpr(node.condition); - MandatoryExpr(node.truePath); - MandatoryExpr(node.falsePath); - - ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition)); - if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; - - if (std::get(conditionValue)) - return AstCloner::Clone(*node.truePath); - else - return AstCloner::Clone(*node.falsePath); + return AstCloner::Clone(*ResolveCondExpression(node)); } ExpressionPtr SanitizeVisitor::Clone(ConstantValueExpression& node) @@ -530,6 +542,41 @@ namespace Nz::ShaderAst return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression } + case Identifier::Type::Function: + { + auto clone = AstCloner::Clone(node); + clone->cachedExpressionType = FunctionType{ identifier->index }; + + return clone; + } + + case Identifier::Type::Intrinsic: + { + assert(identifier->index < m_context->intrinsics.size()); + IntrinsicType intrinsicType = m_context->intrinsics[identifier->index]; + + auto clone = AstCloner::Clone(node); + clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; + + return clone; + } + + case Identifier::Type::Struct: + { + auto clone = AstCloner::Clone(node); + clone->cachedExpressionType = StructType{ identifier->index }; + + return clone; + } + + case Identifier::Type::Type: + { + auto clone = AstCloner::Clone(node); + clone->cachedExpressionType = Type{ identifier->index }; + + return clone; + } + case Identifier::Type::Variable: { // Replace IdentifierExpression by VariableExpression @@ -541,7 +588,7 @@ namespace Nz::ShaderAst } default: - throw AstError{ "expected constant or variable identifier" }; + throw AstError{ "unexpected identifier" }; } } @@ -591,12 +638,12 @@ namespace Nz::ShaderAst throw AstError{ "expected a boolean value" }; if (std::get(conditionValue)) - return AstCloner::Clone(*cond.statement); + return Unscope(AstCloner::Clone(*cond.statement)); } // Every condition failed, fallback to else if any if (node.elseStatement) - return AstCloner::Clone(*node.elseStatement); + return Unscope(AstCloner::Clone(*node.elseStatement)); else return ShaderBuilder::NoOp(); } @@ -680,7 +727,7 @@ namespace Nz::ShaderAst ExpressionType expressionType = ResolveType(GetExpressionType(value)); - if (!IsNoType(clone->type) && ResolveType(clone->type) != expressionType) + if (clone->type.HasValue() && ResolveType(clone->type) != expressionType) throw AstError{ "constant expression doesn't match type" }; clone->type = expressionType; @@ -701,7 +748,7 @@ namespace Nz::ShaderAst UInt32 defaultBlockSet = 0; if (clone->bindingSet.HasValue()) - defaultBlockSet = ComputeAttributeValue(clone->bindingSet); + defaultBlockSet = ComputeExprValue(clone->bindingSet); for (auto& extVar : clone->externalVars) { @@ -709,13 +756,13 @@ namespace Nz::ShaderAst throw AstError{ "external variable " + extVar.name + " requires a binding index" }; if (extVar.bindingSet.HasValue()) - ComputeAttributeValue(extVar.bindingSet); + ComputeExprValue(extVar.bindingSet); else extVar.bindingSet = defaultBlockSet; UInt64 bindingSet = extVar.bindingSet.GetResultingValue(); - UInt64 bindingIndex = ComputeAttributeValue(extVar.bindingIndex); + UInt64 bindingIndex = ComputeExprValue(extVar.bindingIndex); UInt64 bindingKey = bindingSet << 32 | bindingIndex; if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) @@ -728,16 +775,18 @@ namespace Nz::ShaderAst m_context->declaredExternalVar.insert(extVar.name); - extVar.type = ResolveType(extVar.type); + ExpressionType resolvedType = ResolveType(extVar.type); ExpressionType varType; - if (IsUniformType(extVar.type)) - varType = std::get(std::get(extVar.type).containedType); - else if (IsSamplerType(extVar.type)) - varType = extVar.type; + if (IsUniformType(resolvedType)) + varType = std::get(resolvedType).containedType; + else if (IsSamplerType(resolvedType)) + varType = resolvedType; else throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; + extVar.type = std::move(resolvedType); + std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType)); if (!clone->varIndex) clone->varIndex = varIndex; //< First external variable index is node variable index @@ -755,17 +804,26 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->name = node.name; - clone->parameters = node.parameters; - clone->returnType = ResolveType(node.returnType); + + clone->parameters.reserve(node.parameters.size()); + for (auto& parameter : node.parameters) + { + auto& cloneParam = clone->parameters.emplace_back(); + cloneParam.name = parameter.name; + cloneParam.type = ResolveType(parameter.type); + } + + if (node.returnType.HasValue()) + clone->returnType = ResolveType(node.returnType); if (node.depthWrite.HasValue()) - clone->depthWrite = ComputeAttributeValue(node.depthWrite); + clone->depthWrite = ComputeExprValue(node.depthWrite); if (node.earlyFragmentTests.HasValue()) - clone->earlyFragmentTests = ComputeAttributeValue(node.earlyFragmentTests); + clone->earlyFragmentTests = ComputeExprValue(node.earlyFragmentTests); if (node.entryStage.HasValue()) - clone->entryStage = ComputeAttributeValue(node.entryStage); + clone->entryStage = ComputeExprValue(node.entryStage); if (clone->entryStage.HasValue()) { @@ -789,35 +847,13 @@ namespace Nz::ShaderAst } } - Context::CurrentFunctionData tempFuncData; - if (node.entryStage.HasValue()) - tempFuncData.stageType = node.entryStage.GetResultingValue(); + // Function content is resolved in a second pass + auto& pendingFunc = m_context->pendingFunctions.emplace_back(); + pendingFunc.cloneNode = clone.get(); + pendingFunc.node = &node; - m_context->currentFunction = &tempFuncData; - - std::vector* previousList = m_context->currentStatementList; - m_context->currentStatementList = &clone->statements; - - PushScope(); - { - for (auto& parameter : clone->parameters) - { - parameter.type = ResolveType(parameter.type); - std::size_t varIndex = RegisterVariable(parameter.name, parameter.type); - if (!clone->varIndex) - clone->varIndex = varIndex; //< First parameter variable index is node variable index - - SanitizeIdentifier(parameter.name); - } - - clone->statements.reserve(node.statements.size()); - for (auto& statement : node.statements) - clone->statements.push_back(CloneStatement(MandatoryStatement(statement))); - } - PopScope(); - - m_context->currentStatementList = previousList; - m_context->currentFunction = nullptr; + for (auto& parameter : clone->parameters) + parameter.type = ResolveType(parameter.type); if (clone->earlyFragmentTests.HasValue() && clone->earlyFragmentTests.GetResultingValue()) { @@ -825,24 +861,12 @@ namespace Nz::ShaderAst throw AstError{ "discard is not compatible with early fragment tests" }; } - auto it = std::find_if(m_context->functions.begin(), m_context->functions.end(), [&](const auto& funcData) { return funcData.node == &node; }); - assert(it != m_context->functions.end()); - assert(!it->defined); - - std::size_t funcIndex = std::distance(m_context->functions.begin(), it); + FunctionData funcData; + funcData.node = clone.get(); //< update function node + std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData)); clone->funcIndex = funcIndex; - auto& funcData = RegisterFunction(funcIndex); - funcData.flags = tempFuncData.flags; - - for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) - { - assert(i < m_context->functions.size()); - auto& targetFunc = m_context->functions[i]; - targetFunc.calledByFunctions.UnboundedSet(funcIndex); - } - SanitizeIdentifier(clone->name); return clone; @@ -854,11 +878,14 @@ namespace Nz::ShaderAst throw AstError{ "options must be declared outside of functions" }; auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - clone->optType = ResolveType(clone->optType); - if (clone->defaultValue && clone->optType != GetExpressionType(*clone->defaultValue)) + ExpressionType resolvedType = ResolveType(clone->optType); + + if (clone->defaultValue && resolvedType != GetExpressionType(*clone->defaultValue)) throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" }; + clone->optType = std::move(resolvedType); + std::size_t optionIndex = m_context->nextOptionIndex++; if (auto optionValueIt = m_context->options.optionValues.find(optionIndex); optionValueIt != m_context->options.optionValues.end()) @@ -886,35 +913,37 @@ namespace Nz::ShaderAst { if (member.cond.HasValue()) { - member.cond = ComputeAttributeValue(member.cond); + member.cond = ComputeExprValue(member.cond); if (!member.cond.GetResultingValue()) continue; } if (member.builtin.HasValue()) - member.builtin = ComputeAttributeValue(member.builtin); + member.builtin = ComputeExprValue(member.builtin); if (member.locationIndex.HasValue()) - member.locationIndex = ComputeAttributeValue(member.locationIndex); + member.locationIndex = ComputeExprValue(member.locationIndex); if (declaredMembers.find(member.name) != declaredMembers.end()) throw AstError{ "struct member " + member.name + " found multiple time" }; declaredMembers.insert(member.name); - member.type = ResolveType(member.type); + ExpressionType resolvedType = ResolveType(member.type); if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140) { - if (IsPrimitiveType(member.type) && std::get(member.type) == PrimitiveType::Boolean) + if (IsPrimitiveType(resolvedType) && std::get(resolvedType) == PrimitiveType::Boolean) throw AstError{ "boolean type is not allowed in std140 layout" }; - else if (IsStructType(member.type)) + else if (IsStructType(resolvedType)) { - std::size_t structIndex = std::get(member.type).structIndex; + std::size_t structIndex = std::get(resolvedType).structIndex; const StructDescription* desc = m_context->structs[structIndex]; if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) throw AstError{ "inner struct layout mismatch" }; } } + + member.type = std::move(resolvedType); } clone->structIndex = RegisterStruct(clone->description.name, &clone->description); @@ -983,10 +1012,10 @@ namespace Nz::ShaderAst } - AttributeValue unrollValue; + ExpressionValue unrollValue; if (node.unroll.HasValue()) { - unrollValue = ComputeAttributeValue(node.unroll); + unrollValue = ComputeExprValue(node.unroll); if (unrollValue.GetResultingValue() == LoopUnroll::Always) { PushScope(); @@ -1007,7 +1036,7 @@ namespace Nz::ShaderAst Validate(*var); multi->statements.emplace_back(std::move(var)); - multi->statements.emplace_back(CloneStatement(node.statement)); + multi->statements.emplace_back(Unscope(CloneStatement(node.statement))); } }; @@ -1077,7 +1106,7 @@ namespace Nz::ShaderAst auto body = std::make_unique(); body->statements.reserve(2); - body->statements.emplace_back(CloneStatement(node.statement)); + body->statements.emplace_back(Unscope(CloneStatement(node.statement))); ExpressionPtr incrExpr; if (stepVarIndex) @@ -1137,10 +1166,10 @@ namespace Nz::ShaderAst else throw AstError{ "for-each is only supported on arrays and range expressions" }; - AttributeValue unrollValue; + ExpressionValue unrollValue; if (node.unroll.HasValue()) { - unrollValue = ComputeAttributeValue(node.unroll); + unrollValue = ComputeExprValue(node.unroll); if (unrollValue.GetResultingValue() == LoopUnroll::Always) { PushScope(); @@ -1150,9 +1179,8 @@ namespace Nz::ShaderAst if (IsArrayType(exprType)) { const ArrayType& arrayType = std::get(exprType); - UInt32 length = arrayType.length.GetResultingValue(); - for (UInt32 i = 0; i < length; ++i) + for (UInt32 i = 0; i < arrayType.length; ++i) { auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i)); Validate(*accessIndex); @@ -1161,7 +1189,7 @@ namespace Nz::ShaderAst Validate(*elementVariable); multi->statements.emplace_back(std::move(elementVariable)); - multi->statements.emplace_back(CloneStatement(node.statement)); + multi->statements.emplace_back(Unscope(CloneStatement(node.statement))); } } @@ -1180,7 +1208,6 @@ namespace Nz::ShaderAst if (IsArrayType(exprType)) { const ArrayType& arrayType = std::get(exprType); - UInt32 length = arrayType.length.GetResultingValue(); multi->statements.reserve(2); @@ -1196,7 +1223,7 @@ namespace Nz::ShaderAst whileStatement->unroll = std::move(unrollValue); // While condition - auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length)); + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(arrayType.length)); Validate(*condition); whileStatement->condition = std::move(condition); @@ -1211,7 +1238,7 @@ namespace Nz::ShaderAst Validate(*elementVariable); body->statements.emplace_back(std::move(elementVariable)); - body->statements.emplace_back(CloneStatement(node.statement)); + body->statements.emplace_back(Unscope(CloneStatement(node.statement))); auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(1u)); Validate(*incrCounter); @@ -1249,8 +1276,6 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(MultiStatement& node) { - PushScope(); - auto clone = std::make_unique(); clone->statements.reserve(node.statements.size()); @@ -1262,9 +1287,20 @@ namespace Nz::ShaderAst m_context->currentStatementList = previousList; + return clone; + } + + StatementPtr SanitizeVisitor::Clone(ScopedStatement& node) + { + MandatoryStatement(node.statement); + + PushScope(); + + auto scopedClone = AstCloner::Clone(node); + PopScope(); - return clone; + return scopedClone; } StatementPtr SanitizeVisitor::Clone(WhileStatement& node) @@ -1275,10 +1311,10 @@ namespace Nz::ShaderAst auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); Validate(*clone); - AttributeValue unrollValue; + ExpressionValue unrollValue; if (node.unroll.HasValue()) { - clone->unroll = ComputeAttributeValue(node.unroll); + clone->unroll = ComputeExprValue(node.unroll); if (clone->unroll.GetResultingValue() == LoopUnroll::Always) throw AstError{ "unroll(always) is not yet supported on while" }; } @@ -1295,7 +1331,56 @@ namespace Nz::ShaderAst return &*it; } - Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) + template + auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const Identifier* + { + auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) + { + return identifier.name == identifierName && functor(identifier); + }); + if (it == m_context->identifiersInScope.rend()) + return nullptr; + + return &*it; + } + + TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const + { + const auto* identifier = FindIdentifier(identifierName); + if (!identifier) + throw std::runtime_error("identifier " + std::string(identifierName) + " not found"); + + switch (identifier->type) + { + case Identifier::Type::Constant: + return m_context->constantValues[identifier->index]; + + case Identifier::Type::Struct: + return StructType{ identifier->index }; + + case Identifier::Type::Type: + return std::visit([&](auto&& arg) -> TypeParameter + { + return arg; + }, m_context->types[identifier->index]); + + case Identifier::Type::Alias: + throw std::runtime_error("TODO"); + + case Identifier::Type::Function: + throw std::runtime_error("unexpected function identifier"); + + case Identifier::Type::Intrinsic: + throw std::runtime_error("unexpected intrinsic identifier"); + + case Identifier::Type::Variable: + throw std::runtime_error("unexpected variable identifier"); + } + + throw std::runtime_error("internal error"); + } + + Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const { if (!node) throw AstError{ "Invalid expression" }; @@ -1303,7 +1388,7 @@ namespace Nz::ShaderAst return *node; } - Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node) + Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node) const { if (!node) throw AstError{ "Invalid statement" }; @@ -1313,14 +1398,16 @@ namespace Nz::ShaderAst void SanitizeVisitor::PushScope() { - m_context->scopeSizes.push_back(m_context->identifiersInScope.size()); + auto& scope = m_context->scopes.emplace_back(); + scope.previousSize = m_context->identifiersInScope.size(); } void SanitizeVisitor::PopScope() { - assert(!m_context->scopeSizes.empty()); - m_context->identifiersInScope.resize(m_context->scopeSizes.back()); - m_context->scopeSizes.pop_back(); + assert(!m_context->scopes.empty()); + auto& scope = m_context->scopes.back(); + m_context->identifiersInScope.resize(scope.previousSize); + m_context->scopes.pop_back(); } ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression) @@ -1342,11 +1429,21 @@ namespace Nz::ShaderAst return varExpr; } + ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr) const + { + // Run optimizer on constant value to hopefully retrieve a single constant value + ExpressionPtr optimizedExpr = Optimize(expr); + if (optimizedExpr->GetType() != NodeType::ConstantValueExpression) + throw AstError{"expected a constant expression"}; + + return static_cast(*optimizedExpr).value; + } + template - const T& SanitizeVisitor::ComputeAttributeValue(AttributeValue& attribute) + const T& SanitizeVisitor::ComputeExprValue(ExpressionValue& attribute) const { if (!attribute.HasValue()) - throw AstError{"attribute expected a value"}; + throw AstError{ "attribute expected a value" }; if (attribute.IsExpression()) { @@ -1372,18 +1469,8 @@ namespace Nz::ShaderAst return attribute.GetResultingValue(); } - ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr) - { - // Run optimizer on constant value to hopefully retrieve a single constant value - ExpressionPtr optimizedExpr = Optimize(expr); - if (optimizedExpr->GetType() != NodeType::ConstantValueExpression) - throw AstError{"expected a constant expression"}; - - return static_cast(*optimizedExpr).value; - } - template - std::unique_ptr SanitizeVisitor::Optimize(T& node) + std::unique_ptr SanitizeVisitor::Optimize(T& node) const { AstOptimizer::Options optimizerOptions; optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& @@ -1396,28 +1483,180 @@ namespace Nz::ShaderAst return static_unique_pointer_cast(ShaderAst::Optimize(node, optimizerOptions)); } - std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl) - { - std::size_t functionIndex = m_context->functions.size(); - auto& funcData = m_context->functions.emplace_back(); - funcData.node = &funcDecl; - - return functionIndex; - } - void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) { assert(funcIndex < m_context->functions.size()); auto& funcData = m_context->functions[funcIndex]; - if (!funcData.defined) - return; - funcData.flags |= flags; for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) PropagateFunctionFlags(i, funcData.flags, seen); } + void SanitizeVisitor::RegisterBuiltin() + { + // Primitive types + RegisterType("bool", PrimitiveType::Boolean); + RegisterType("f32", PrimitiveType::Float32); + RegisterType("i32", PrimitiveType::Int32); + RegisterType("u32", PrimitiveType::UInt32); + + // Partial types + + // Array + RegisterType("array", PartialType { + { TypeParameterCategory::FullType, TypeParameterCategory::ConstantValue }, + [=](const TypeParameter* parameters, std::size_t parameterCount) -> ExpressionType + { + assert(parameterCount == 2); + assert(std::holds_alternative(parameters[0])); + assert(std::holds_alternative(parameters[1])); + + const ExpressionType& exprType = std::get(parameters[0]); + const ConstantValue& length = std::get(parameters[1]); + + UInt32 lengthValue; + if (std::holds_alternative(length)) + { + Int32 value = std::get(length); + if (value <= 0) + throw AstError{ "array length must a positive integer" }; + + lengthValue = SafeCast(value); + } + else if (std::holds_alternative(length)) + { + lengthValue = std::get(length); + if (lengthValue == 0) + throw AstError{ "array length must a positive integer" }; + } + else + throw AstError{ "array length must a positive integer" }; + + ArrayType arrayType; + arrayType.containedType = std::make_unique(); + arrayType.containedType->type = exprType; + arrayType.length = lengthValue; + + return arrayType; + } + }); + + // matX + for (std::size_t componentCount = 2; componentCount <= 4; ++componentCount) + { + RegisterType("mat" + std::to_string(componentCount), PartialType { + { TypeParameterCategory::PrimitiveType }, + [=](const TypeParameter* parameters, std::size_t parameterCount) -> ExpressionType + { + assert(parameterCount == 1); + assert(std::holds_alternative(*parameters)); + + const ExpressionType& exprType = std::get(*parameters); + assert(IsPrimitiveType(exprType)); + + return MatrixType { + componentCount, componentCount, std::get(exprType) + }; + } + }); + } + + // vecX + for (std::size_t componentCount = 2; componentCount <= 4; ++componentCount) + { + RegisterType("vec" + std::to_string(componentCount), PartialType { + { TypeParameterCategory::PrimitiveType }, + [=](const TypeParameter* parameters, std::size_t parameterCount) -> ExpressionType + { + assert(parameterCount == 1); + assert(std::holds_alternative(*parameters)); + + const ExpressionType& exprType = std::get(*parameters); + assert(IsPrimitiveType(exprType)); + + return VectorType { + componentCount, std::get(exprType) + }; + } + }); + } + + // samplers + struct SamplerInfo + { + std::string typeName; + ImageType imageType; + }; + + std::array samplerInfos = { + { + { + "sampler2D", + ImageType::E2D + }, + { + "samplerCube", + ImageType::Cubemap + } + } + }; + + for (SamplerInfo& sampler : samplerInfos) + { + RegisterType(std::move(sampler.typeName), PartialType { + { TypeParameterCategory::PrimitiveType }, + [=](const TypeParameter* parameters, std::size_t parameterCount) -> ExpressionType + { + assert(parameterCount == 1); + assert(std::holds_alternative(*parameters)); + + const ExpressionType& exprType = std::get(*parameters); + assert(IsPrimitiveType(exprType)); + + PrimitiveType primitiveType = std::get(exprType); + + // TODO: Add support for integer samplers + if (primitiveType != PrimitiveType::Float32) + throw AstError{ "for now only f32 samplers are supported" }; + + return SamplerType { + sampler.imageType, primitiveType + }; + } + }); + } + + // uniform + RegisterType("uniform", PartialType { + { TypeParameterCategory::StructType }, + [=](const TypeParameter* parameters, std::size_t parameterCount) -> ExpressionType + { + assert(parameterCount == 1); + assert(std::holds_alternative(*parameters)); + + const ExpressionType& exprType = std::get(*parameters); + assert(IsStructType(exprType)); + + StructType structType = std::get(exprType); + return UniformType { + structType + }; + } + }); + + // Intrinsics + RegisterIntrinsic("cross", IntrinsicType::CrossProduct); + RegisterIntrinsic("dot", IntrinsicType::DotProduct); + RegisterIntrinsic("exp", IntrinsicType::Exp); + RegisterIntrinsic("length", IntrinsicType::Length); + RegisterIntrinsic("max", IntrinsicType::Max); + RegisterIntrinsic("min", IntrinsicType::Min); + RegisterIntrinsic("normalize", IntrinsicType::Normalize); + RegisterIntrinsic("pow", IntrinsicType::Pow); + RegisterIntrinsic("reflect", IntrinsicType::Reflect); + } + std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value) { if (FindIdentifier(name)) @@ -1435,14 +1674,9 @@ namespace Nz::ShaderAst return constantIndex; } - auto SanitizeVisitor::RegisterFunction(std::size_t functionIndex) -> FunctionData& + std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData) { - assert(m_context->functions.size() >= functionIndex); - auto& funcData = m_context->functions[functionIndex]; - assert(!funcData.defined); - funcData.defined = true; - - if (auto* identifier = FindIdentifier(funcData.node->name)) + if (auto* identifier = FindIdentifier(name)) { bool duplicate = true; @@ -1458,13 +1692,17 @@ namespace Nz::ShaderAst throw AstError{ funcData.node->name + " is already used" }; } + std::size_t functionIndex = m_context->functions.size(); + + m_context->functions.emplace_back(std::move(funcData)); + m_context->identifiersInScope.push_back({ - funcData.node->name, + std::move(name), functionIndex, Identifier::Type::Function }); - return funcData; + return functionIndex; } std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type) @@ -1501,11 +1739,48 @@ namespace Nz::ShaderAst return structIndex; } + std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType) + { + if (FindIdentifier(name)) + throw AstError{ name + " is already used" }; + + std::size_t typeIndex = m_context->types.size(); + m_context->types.emplace_back(std::move(expressionType)); + + m_context->identifiersInScope.push_back({ + std::move(name), + typeIndex, + Identifier::Type::Type + }); + + return typeIndex; + } + + std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType) + { + if (FindIdentifier(name)) + throw AstError{ name + " is already used" }; + + std::size_t typeIndex = m_context->types.size(); + m_context->types.emplace_back(std::move(partialType)); + + m_context->identifiersInScope.push_back({ + std::move(name), + typeIndex, + Identifier::Type::Type + }); + + return typeIndex; + } + std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type) { - // Allow variable shadowing - if (auto* identifier = FindIdentifier(name); identifier && identifier->type != Identifier::Type::Variable) - throw AstError{ name + " is already used" }; + if (auto* identifier = FindIdentifier(name)) + { + // Allow variable shadowing + if (identifier->type != Identifier::Type::Variable) + throw AstError{ name + " is already used" }; + } std::size_t varIndex = m_context->variableTypes.size(); m_context->variableTypes.emplace_back(std::move(type)); @@ -1521,7 +1796,46 @@ namespace Nz::ShaderAst void SanitizeVisitor::ResolveFunctions() { - // Once every function is known, we can propagate flags + // Once every function is known, we can evaluate function content + for (auto& pendingFunc : m_context->pendingFunctions) + { + PushScope(); + + for (auto& parameter : pendingFunc.cloneNode->parameters) + { + std::size_t varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue()); + if (!pendingFunc.cloneNode->varIndex) + pendingFunc.cloneNode->varIndex = varIndex; //< First parameter variable index is node variable index + + SanitizeIdentifier(parameter.name); + } + + Context::CurrentFunctionData tempFuncData; + if (pendingFunc.cloneNode->entryStage.HasValue()) + tempFuncData.stageType = pendingFunc.cloneNode->entryStage.GetResultingValue(); + + m_context->currentFunction = &tempFuncData; + + std::vector* previousList = m_context->currentStatementList; + m_context->currentStatementList = &pendingFunc.cloneNode->statements; + + pendingFunc.cloneNode->statements.reserve(pendingFunc.node->statements.size()); + for (auto& statement : pendingFunc.node->statements) + pendingFunc.cloneNode->statements.push_back(CloneStatement(MandatoryStatement(statement))); + + m_context->currentStatementList = previousList; + m_context->currentFunction = nullptr; + + std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex; + for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) + { + assert(i < m_context->functions.size()); + auto& targetFunc = m_context->functions[i]; + targetFunc.calledByFunctions.UnboundedSet(funcIndex); + } + + PopScope(); + } Bitset<> seen; for (std::size_t funcIndex = 0; funcIndex < m_context->functions.size(); ++funcIndex) @@ -1539,6 +1853,23 @@ namespace Nz::ShaderAst } } + const ExpressionPtr& SanitizeVisitor::ResolveCondExpression(ConditionalExpression& node) + { + MandatoryExpr(node.condition); + MandatoryExpr(node.truePath); + MandatoryExpr(node.falsePath); + + ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition)); + if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "expected a boolean value" }; + + if (std::get(conditionValue)) + return node.truePath; + else + return node.falsePath; + + } + std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType) { return std::visit([&](auto&& arg) -> std::size_t @@ -1549,9 +1880,13 @@ namespace Nz::ShaderAst return ResolveStruct(arg); else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { throw AstError{ "expression is not a structure" }; @@ -1580,72 +1915,40 @@ namespace Nz::ShaderAst std::size_t SanitizeVisitor::ResolveStruct(const UniformType& uniformType) { - return std::visit([&](auto&& arg) -> std::size_t - { - using T = std::decay_t; - - if constexpr (std::is_same_v || std::is_same_v) - return ResolveStruct(arg); - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, uniformType.containedType); + return uniformType.containedType.structIndex; } ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType) { - return std::visit([&](auto&& arg) -> ExpressionType - { - using T = std::decay_t; + if (!IsTypeExpression(exprType)) + return exprType; - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) - { - return exprType; - } - else if constexpr (std::is_same_v) - { - ArrayType resolvedArrayType; - if (arg.length.IsExpression()) - { - resolvedArrayType.length = CloneExpression(arg.length.GetExpression()); - ComputeAttributeValue(resolvedArrayType.length); - } - else if (arg.length.IsResultingValue()) - resolvedArrayType.length = arg.length.GetResultingValue(); + std::size_t typeIndex = std::get(exprType).typeIndex; - resolvedArrayType.containedType = std::make_unique(); - resolvedArrayType.containedType->type = ResolveType(arg.containedType->type); + const auto& type = m_context->types[typeIndex]; + if (std::holds_alternative(type)) + throw AstError{ "full type expected" }; - return resolvedArrayType; - } - else if constexpr (std::is_same_v) - { - const Identifier* identifier = FindIdentifier(arg.name); - if (!identifier) - throw AstError{ "unknown identifier " + arg.name }; + return std::get(type); + } - if (identifier->type != Identifier::Type::Struct) - throw AstError{ "expected type identifier" }; + ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue& exprTypeValue) + { + if (!exprTypeValue.HasValue()) + return {}; - return StructType{ identifier->index }; - } - else if constexpr (std::is_same_v) - { - return std::visit([&](auto&& containedArg) - { - ExpressionType resolvedType = ResolveType(containedArg); - assert(std::holds_alternative(resolvedType)); + if (exprTypeValue.IsResultingValue()) + return ResolveType(exprTypeValue.GetResultingValue()); - return UniformType{ std::get(resolvedType) }; - }, arg.containedType); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, exprType); + assert(exprTypeValue.IsExpression()); + ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression()); + assert(expression->cachedExpressionType); + + const ExpressionType& exprType = expression->cachedExpressionType.value(); + //if (!IsTypeType(exprType)) + // throw AstError{ "type expected" }; + + return ResolveType(exprType); } void SanitizeVisitor::SanitizeIdentifier(std::string& identifier) @@ -1661,6 +1964,27 @@ namespace Nz::ShaderAst } } + void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const + { + return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); + } + + void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const + { + if (left != right) + throw AstError{ "Left expression type must match right expression type" }; + } + + StatementPtr SanitizeVisitor::Unscope(StatementPtr node) + { + assert(node); + + if (node->GetType() == NodeType::ScopedStatement) + return std::move(static_cast(*node).statement); + else + return node; + } + void SanitizeVisitor::Validate(WhileStatement& node) { if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean }) @@ -1669,65 +1993,129 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(AccessIndexExpression& node) { - if (node.indices.empty()) - throw AstError{ "AccessIndexExpression must have at least one index" }; - - for (auto& index : node.indices) - { - const ShaderAst::ExpressionType& indexType = GetExpressionType(*index); - if (!IsPrimitiveType(indexType)) - throw AstError{ "AccessIndex expects integer indices" }; - - PrimitiveType primitiveIndexType = std::get(indexType); - if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32) - throw AstError{ "AccessIndex expects integer indices" }; - } - ExpressionType exprType = GetExpressionType(*node.expr); - for (const auto& indexExpr : node.indices) + if (IsTypeExpression(exprType)) { - if (IsArrayType(exprType)) + std::size_t typeIndex = std::get(exprType).typeIndex; + const auto& type = m_context->types[typeIndex]; + + if (!std::holds_alternative(type)) + throw std::runtime_error("only partial types can be specialized"); + + const PartialType& partialType = std::get(type); + if (partialType.parameters.size() != node.indices.size()) + throw std::runtime_error("parameter count mismatch"); + + StackVector parameters = NazaraStackVector(TypeParameter, partialType.parameters.size()); + for (std::size_t i = 0; i < partialType.parameters.size(); ++i) { - const ArrayType& arrayType = std::get(exprType); - ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType - exprType = std::move(containedType); + ExpressionPtr indexExpr = CloneExpression(node.indices[i]); + switch (partialType.parameters[i]) + { + case TypeParameterCategory::ConstantValue: + { + parameters.push_back(ComputeConstantValue(*indexExpr)); + break; + } + + case TypeParameterCategory::FullType: + case TypeParameterCategory::PrimitiveType: + case TypeParameterCategory::StructType: + { + ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr)); + + switch (partialType.parameters[i]) + { + case TypeParameterCategory::PrimitiveType: + { + if (!IsPrimitiveType(resolvedType)) + throw std::runtime_error("expected a primitive type"); + + break; + } + + case TypeParameterCategory::StructType: + { + if (!IsStructType(resolvedType)) + throw std::runtime_error("expected a struct type"); + + break; + } + + default: + break; + } + + parameters.push_back(resolvedType); + break; + } + } } - else if (IsStructType(exprType)) - { - const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); - if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) - throw AstError{ "struct can only be accessed with constant i32 indices" }; - ConstantValueExpression& constantExpr = static_cast(*indexExpr); - - Int32 index = std::get(constantExpr.value); - - std::size_t structIndex = ResolveStruct(exprType); - assert(structIndex < m_context->structs.size()); - const StructDescription* s = m_context->structs[structIndex]; - - exprType = ResolveType(s->members[index].type); - } - else if (IsMatrixType(exprType)) - { - // Matrix index (ex: mat[2]) - MatrixType matrixType = std::get(exprType); - - //TODO: Handle row-major matrices - exprType = VectorType{ matrixType.rowCount, matrixType.type }; - } - else if (IsVectorType(exprType)) - { - // Swizzle expression with one component (ex: vec[2]) - VectorType swizzledVec = std::get(exprType); - - exprType = swizzledVec.type; - } - else - throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays + assert(parameters.size() == partialType.parameters.size()); + node.cachedExpressionType = partialType.buildFunc(parameters.data(), parameters.size()); } + else + { + if (node.indices.size() != 1) + throw AstError{ "AccessIndexExpression must have at one index" }; - node.cachedExpressionType = std::move(exprType); + for (auto& index : node.indices) + { + const ShaderAst::ExpressionType& indexType = GetExpressionType(*index); + if (!IsPrimitiveType(indexType)) + throw AstError{ "AccessIndex expects integer indices" }; + + PrimitiveType primitiveIndexType = std::get(indexType); + if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32) + throw AstError{ "AccessIndex expects integer indices" }; + } + + for (const auto& indexExpr : node.indices) + { + if (IsArrayType(exprType)) + { + const ArrayType& arrayType = std::get(exprType); + ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType + exprType = std::move(containedType); + } + else if (IsStructType(exprType)) + { + const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); + if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) + throw AstError{ "struct can only be accessed with constant i32 indices" }; + + ConstantValueExpression& constantExpr = static_cast(*indexExpr); + + Int32 index = std::get(constantExpr.value); + + std::size_t structIndex = ResolveStruct(exprType); + assert(structIndex < m_context->structs.size()); + const StructDescription* s = m_context->structs[structIndex]; + + exprType = ResolveType(s->members[index].type); + } + else if (IsMatrixType(exprType)) + { + // Matrix index (ex: mat[2]) + MatrixType matrixType = std::get(exprType); + + //TODO: Handle row-major matrices + exprType = VectorType{ matrixType.rowCount, matrixType.type }; + } + else if (IsVectorType(exprType)) + { + // Swizzle expression with one component (ex: vec[2]) + VectorType swizzledVec = std::get(exprType); + + exprType = swizzledVec.type; + } + else + throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays + } + + node.cachedExpressionType = std::move(exprType); + } } void SanitizeVisitor::Validate(AssignExpression& node) @@ -1771,35 +2159,43 @@ namespace Nz::ShaderAst node.cachedExpressionType = ValidateBinaryOp(node.op, node.left, node.right); } - void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration) + void SanitizeVisitor::Validate(CallFunctionExpression& node) { + const ShaderAst::ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction); + assert(std::holds_alternative(targetFuncType)); + + std::size_t targetFuncIndex = std::get(targetFuncType).funcIndex; + assert(targetFuncIndex < m_context->functions.size()); + auto& funcData = m_context->functions[targetFuncIndex]; + + const DeclareFunctionStatement* referenceDeclaration = funcData.node; + if (referenceDeclaration->entryStage.HasValue()) throw AstError{ referenceDeclaration->name + " is an entry function which cannot be called by the program" }; for (std::size_t i = 0; i < node.parameters.size(); ++i) { - if (GetExpressionType(*node.parameters[i]) != referenceDeclaration->parameters[i].type) + if (GetExpressionType(*node.parameters[i]) != referenceDeclaration->parameters[i].type.GetResultingValue()) throw AstError{ "function " + referenceDeclaration->name + " parameter " + std::to_string(i) + " type mismatch" }; } if (node.parameters.size() != referenceDeclaration->parameters.size()) throw AstError{ "function " + referenceDeclaration->name + " expected " + std::to_string(referenceDeclaration->parameters.size()) + " parameters, got " + std::to_string(node.parameters.size()) }; - node.cachedExpressionType = referenceDeclaration->returnType; + node.cachedExpressionType = referenceDeclaration->returnType.GetResultingValue(); } void SanitizeVisitor::Validate(CastExpression& node) { - node.targetType = ResolveType(node.targetType); - node.cachedExpressionType = node.targetType; + ExpressionType resolvedType = ResolveType(node.targetType); const auto& firstExprPtr = node.expressions.front(); if (!firstExprPtr) throw AstError{ "expected at least one expression" }; - if (IsMatrixType(node.targetType)) + if (IsMatrixType(resolvedType)) { - const MatrixType& targetMatrixType = std::get(node.targetType); + const MatrixType& targetMatrixType = std::get(resolvedType); const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr); if (IsMatrixType(firstExprType)) @@ -1808,7 +2204,6 @@ namespace Nz::ShaderAst throw AstError{ "too many expressions" }; // Matrix to matrix cast: always valid - return; } else { @@ -1829,48 +2224,54 @@ namespace Nz::ShaderAst } } } - - auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t + else { - if (IsVectorType(exprType)) - return std::get(exprType).componentCount; - else + auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t { - assert(IsPrimitiveType(exprType)); - return 1; + if (IsVectorType(exprType)) + return std::get(exprType).componentCount; + else + { + assert(IsPrimitiveType(exprType)); + return 1; + } + }; + + std::size_t componentCount = 0; + std::size_t requiredComponents = GetComponentCount(resolvedType); + + for (auto& exprPtr : node.expressions) + { + if (!exprPtr) + break; + + const ExpressionType& exprType = GetExpressionType(*exprPtr); + if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) + throw AstError{ "incompatible type" }; + + componentCount += GetComponentCount(exprType); } - }; - std::size_t componentCount = 0; - std::size_t requiredComponents = GetComponentCount(node.targetType); - - for (auto& exprPtr : node.expressions) - { - if (!exprPtr) - break; - - const ExpressionType& exprType = GetExpressionType(*exprPtr); - if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) - throw AstError{ "incompatible type" }; - - componentCount += GetComponentCount(exprType); + if (componentCount != requiredComponents) + throw AstError{ "component count doesn't match required component count" }; } - if (componentCount != requiredComponents) - throw AstError{ "component count doesn't match required component count" }; + node.cachedExpressionType = resolvedType; + node.targetType = std::move(resolvedType); } void SanitizeVisitor::Validate(DeclareVariableStatement& node) { - if (IsNoType(node.varType)) + ExpressionType resolvedType; + if (!node.varType.HasValue()) { if (!node.initialExpression) throw AstError{ "variable must either have a type or an initial value" }; - node.varType = ResolveType(GetExpressionType(*node.initialExpression)); + resolvedType = ResolveType(GetExpressionType(*node.initialExpression)); } else - node.varType = ResolveType(node.varType); + resolvedType = ResolveType(node.varType); if (m_context->options.makeVariableNameUnique && FindIdentifier(node.varName) != nullptr) { @@ -1886,7 +2287,8 @@ namespace Nz::ShaderAst node.varName = std::move(candidateName); } - node.varIndex = RegisterVariable(node.varName, node.varType); + node.varIndex = RegisterVariable(node.varName, resolvedType); + node.varType = std::move(resolvedType); SanitizeIdentifier(node.varName); } @@ -2299,15 +2701,4 @@ namespace Nz::ShaderAst throw AstError{ "internal error: unchecked operation" }; } - - void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) - { - return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); - } - - void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) - { - if (left != right) - throw AstError{ "Left expression type must match right expression type" }; - } } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index b7207e80d..315d6b31a 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -43,7 +43,7 @@ namespace Nz AstRecursiveVisitor::Visit(node); assert(currentFunction); - currentFunction->calledFunctions.UnboundedSet(std::get(node.targetFunction)); + currentFunction->calledFunctions.UnboundedSet(std::get(GetExpressionType(*node.targetFunction)).funcIndex); } void Visit(ShaderAst::ConditionalExpression& /*node*/) override @@ -227,6 +227,24 @@ namespace Nz throw std::runtime_error("unexpected ArrayType"); } + void GlslWriter::Append(ShaderAst::BuiltinEntry builtin) + { + switch (builtin) + { + case ShaderAst::BuiltinEntry::FragCoord: + Append("gl_FragCoord"); + break; + + case ShaderAst::BuiltinEntry::FragDepth: + Append("gl_FragDepth"); + break; + + case ShaderAst::BuiltinEntry::VertexPosition: + Append("gl_Position"); + break; + } + } + void GlslWriter::Append(const ShaderAst::ExpressionType& type) { std::visit([&](auto&& arg) @@ -235,22 +253,14 @@ namespace Nz }, type); } - void GlslWriter::Append(ShaderAst::BuiltinEntry builtin) + void GlslWriter::Append(const ShaderAst::ExpressionValue& type) { - switch (builtin) - { - case ShaderAst::BuiltinEntry::FragCoord: - Append("gl_FragCoord"); - break; + Append(type.GetResultingValue()); + } - case ShaderAst::BuiltinEntry::FragDepth: - Append("gl_FragDepth"); - break; - - case ShaderAst::BuiltinEntry::VertexPosition: - Append("gl_Position"); - break; - } + void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/) + { + throw std::runtime_error("unexpected function type"); } void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/) @@ -258,6 +268,11 @@ namespace Nz throw std::runtime_error("unexpected identifier type"); } + void GlslWriter::Append(const ShaderAst::IntrinsicFunctionType& /*intrinsicFunctionType*/) + { + throw std::runtime_error("unexpected intrinsic function type"); + } + void GlslWriter::Append(const ShaderAst::MatrixType& matrixType) { if (matrixType.columnCount == matrixType.rowCount) @@ -274,6 +289,11 @@ namespace Nz } } + void GlslWriter::Append(const ShaderAst::MethodType& methodType) + { + throw std::runtime_error("unexpected method type"); + } + void GlslWriter::Append(ShaderAst::PrimitiveType type) { switch (type) @@ -316,6 +336,11 @@ namespace Nz Append(structDesc->name); } + void GlslWriter::Append(const ShaderAst::Type& /*type*/) + { + throw std::runtime_error("unexpected Type"); + } + void GlslWriter::Append(const ShaderAst::UniformType& /*uniformType*/) { throw std::runtime_error("unexpected UniformType"); @@ -386,7 +411,7 @@ namespace Nz first = false; - AppendVariableDeclaration(parameter.type, parameter.name); + AppendVariableDeclaration(parameter.type.GetResultingValue(), parameter.name); } AppendLine((forward) ? ");" : ")"); } @@ -506,13 +531,13 @@ namespace Nz { if (ShaderAst::IsArrayType(varType)) { - std::vector*> lengths; + std::vector lengths; const ShaderAst::ExpressionType* exprType = &varType; while (ShaderAst::IsArrayType(*exprType)) { const auto& arrayType = std::get(*exprType); - lengths.push_back(&arrayType.length); + lengths.push_back(arrayType.length); exprType = &arrayType.containedType->type; } @@ -520,17 +545,8 @@ namespace Nz assert(!ShaderAst::IsArrayType(*exprType)); Append(*exprType, " ", varName); - for (const auto* lengthAttribute : lengths) - { - Append("["); - - if (lengthAttribute->IsResultingValue()) - Append(lengthAttribute->GetResultingValue()); - else - lengthAttribute->GetExpression()->Visit(*this); - - Append("]"); - } + for (UInt32 lengthAttribute : lengths) + Append("[", lengthAttribute, "]"); } else Append(varType, " ", varName); @@ -582,8 +598,8 @@ namespace Nz const std::string& varName = parameter.name; RegisterVariable(*node.varIndex, varName); - assert(IsStructType(parameter.type)); - std::size_t structIndex = std::get(parameter.type).structIndex; + assert(IsStructType(parameter.type.GetResultingValue())); + std::size_t structIndex = std::get(parameter.type.GetResultingValue()).structIndex; const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); AppendLine(structDesc->name, " ", varName, ";"); @@ -631,7 +647,7 @@ namespace Nz Append("layout(location = "); Append(member.locationIndex.GetResultingValue()); Append(") ", keyword, " "); - AppendVariableDeclaration(member.type, targetPrefix + member.name); + AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name); AppendLine(";"); fields.push_back({ @@ -651,9 +667,9 @@ namespace Nz { assert(node.parameters.size() == 1); auto& parameter = node.parameters.front(); - assert(std::holds_alternative(parameter.type)); + assert(std::holds_alternative(parameter.type.GetResultingValue())); - std::size_t inputStructIndex = std::get(parameter.type).structIndex; + std::size_t inputStructIndex = std::get(parameter.type.GetResultingValue()).structIndex; inputStruct = Retrieve(m_currentState->structs, inputStructIndex); AppendCommentSection("Inputs"); @@ -666,10 +682,10 @@ namespace Nz AppendLine(); } - if (!IsNoType(node.returnType)) + if (node.returnType.HasValue()) { - assert(std::holds_alternative(node.returnType)); - std::size_t outputStructIndex = std::get(node.returnType).structIndex; + assert(std::holds_alternative(node.returnType.GetResultingValue())); + std::size_t outputStructIndex = std::get(node.returnType.GetResultingValue()).structIndex; const ShaderAst::StructDescription* outputStruct = Retrieve(m_currentState->structs, outputStructIndex); @@ -690,6 +706,18 @@ namespace Nz m_currentState->variableNames.emplace(varIndex, std::move(varName)); } + void GlslWriter::ScopeVisit(ShaderAst::Statement& node) + { + if (node.GetType() != ShaderAst::NodeType::ScopedStatement) + { + EnterScope(); + node.Visit(*this); + LeaveScope(true); + } + else + node.Visit(*this); + } + void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired) { bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue); @@ -722,12 +750,10 @@ namespace Nz assert(!IsStructType(exprType)); // Array access - for (ShaderAst::ExpressionPtr& expr : node.indices) - { - Append("["); - Visit(expr); - Append("]"); - } + assert(node.indices.size() == 1); + Append("["); + Visit(node.indices.front()); + Append("]"); } void GlslWriter::Visit(ShaderAst::AssignExpression& node) @@ -775,8 +801,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node) { - assert(std::holds_alternative(node.targetFunction)); - const std::string& targetName = Retrieve(m_currentState->previsitor.functions, std::get(node.targetFunction)).name; + std::size_t functionIndex = std::get(GetExpressionType(*node.targetFunction)).funcIndex; + const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name; Append(targetName, "("); for (std::size_t i = 0; i < node.parameters.size(); ++i) @@ -946,9 +972,7 @@ namespace Nz statement.condition->Visit(*this); AppendLine(")"); - EnterScope(); - statement.statement->Visit(*this); - LeaveScope(); + ScopeVisit(*statement.statement); first = false; } @@ -957,9 +981,7 @@ namespace Nz { AppendLine("else"); - EnterScope(); - node.elseStatement->Visit(*this); - LeaveScope(); + ScopeVisit(*node.elseStatement); } } @@ -976,13 +998,10 @@ namespace Nz for (const auto& externalVar : node.externalVars) { bool isStd140 = false; - if (IsUniformType(externalVar.type)) + if (IsUniformType(externalVar.type.GetResultingValue())) { - auto& uniform = std::get(externalVar.type); - assert(std::holds_alternative(uniform.containedType)); - - std::size_t structIndex = std::get(uniform.containedType).structIndex; - ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, structIndex); + auto& uniform = std::get(externalVar.type.GetResultingValue()); + ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex); if (structInfo->layout.HasValue()) isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140; } @@ -1018,18 +1037,15 @@ namespace Nz Append("uniform "); - if (IsUniformType(externalVar.type)) + if (IsUniformType(externalVar.type.GetResultingValue())) { Append("_NzBinding_"); AppendLine(externalVar.name); EnterScope(); { - auto& uniform = std::get(externalVar.type); - assert(std::holds_alternative(uniform.containedType)); - - std::size_t structIndex = std::get(uniform.containedType).structIndex; - auto& structDesc = Retrieve(m_currentState->structs, structIndex); + auto& uniform = std::get(externalVar.type.GetResultingValue()); + auto& structDesc = Retrieve(m_currentState->structs, uniform.containedType.structIndex); bool first = true; for (const auto& member : structDesc->members) @@ -1042,7 +1058,7 @@ namespace Nz first = false; - AppendVariableDeclaration(member.type, member.name); + AppendVariableDeclaration(member.type.GetResultingValue(), member.name); Append(";"); } } @@ -1052,11 +1068,11 @@ namespace Nz Append(externalVar.name); } else - AppendVariableDeclaration(externalVar.type, externalVar.name); + AppendVariableDeclaration(externalVar.type.GetResultingValue(), externalVar.name); AppendLine(";"); - if (IsUniformType(externalVar.type)) + if (IsUniformType(externalVar.type.GetResultingValue())) AppendLine(); RegisterVariable(varIndex++, externalVar.name); @@ -1138,7 +1154,7 @@ namespace Nz first = false; - AppendVariableDeclaration(member.type, member.name); + AppendVariableDeclaration(member.type.GetResultingValue(), member.name); Append(";"); } } @@ -1151,7 +1167,7 @@ namespace Nz assert(node.varIndex); RegisterVariable(*node.varIndex, node.varName); - AppendVariableDeclaration(node.varType, node.varName); + AppendVariableDeclaration(node.varType.GetResultingValue(), node.varName); if (node.initialExpression) { Append(" = "); @@ -1239,15 +1255,20 @@ namespace Nz } } + void GlslWriter::Visit(ShaderAst::ScopedStatement& node) + { + EnterScope(); + node.statement->Visit(*this); + LeaveScope(true); + } + void GlslWriter::Visit(ShaderAst::WhileStatement& node) { Append("while ("); node.condition->Visit(*this); AppendLine(")"); - EnterScope(); - node.body->Visit(*this); - LeaveScope(); + ScopeVisit(*node.body); } bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader) diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 4646bc6ad..43d1b750c 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -29,63 +29,63 @@ namespace Nz struct LangWriter::BindingAttribute { - const ShaderAst::AttributeValue& bindingIndex; + const ShaderAst::ExpressionValue& bindingIndex; inline bool HasValue() const { return bindingIndex.HasValue(); } }; struct LangWriter::BuiltinAttribute { - const ShaderAst::AttributeValue& builtin; + const ShaderAst::ExpressionValue& builtin; inline bool HasValue() const { return builtin.HasValue(); } }; struct LangWriter::DepthWriteAttribute { - const ShaderAst::AttributeValue& writeMode; + const ShaderAst::ExpressionValue& writeMode; inline bool HasValue() const { return writeMode.HasValue(); } }; struct LangWriter::EarlyFragmentTestsAttribute { - const ShaderAst::AttributeValue& earlyFragmentTests; + const ShaderAst::ExpressionValue& earlyFragmentTests; inline bool HasValue() const { return earlyFragmentTests.HasValue(); } }; struct LangWriter::EntryAttribute { - const ShaderAst::AttributeValue& stageType; + const ShaderAst::ExpressionValue& stageType; inline bool HasValue() const { return stageType.HasValue(); } }; struct LangWriter::LayoutAttribute { - const ShaderAst::AttributeValue& layout; + const ShaderAst::ExpressionValue& layout; inline bool HasValue() const { return layout.HasValue(); } }; struct LangWriter::LocationAttribute { - const ShaderAst::AttributeValue& locationIndex; + const ShaderAst::ExpressionValue& locationIndex; inline bool HasValue() const { return locationIndex.HasValue(); } }; struct LangWriter::SetAttribute { - const ShaderAst::AttributeValue& setIndex; + const ShaderAst::ExpressionValue& setIndex; inline bool HasValue() const { return setIndex.HasValue(); } }; struct LangWriter::UnrollAttribute { - const ShaderAst::AttributeValue& unroll; + const ShaderAst::ExpressionValue& unroll; inline bool HasValue() const { return unroll.HasValue(); } }; @@ -126,14 +126,7 @@ namespace Nz void LangWriter::Append(const ShaderAst::ArrayType& type) { - Append("array[", type.containedType->type, ", "); - - if (type.length.IsResultingValue()) - Append(type.length.GetResultingValue()); - else - type.length.GetExpression()->Visit(*this); - - Append("]"); + Append("array[", type.containedType->type, ", ", type.length, "]"); } void LangWriter::Append(const ShaderAst::ExpressionType& type) @@ -144,11 +137,26 @@ namespace Nz }, type); } + void LangWriter::Append(const ShaderAst::ExpressionValue& type) + { + Append(type.GetResultingValue()); + } + + void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/) + { + throw std::runtime_error("unexpected function type"); + } + void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/) { throw std::runtime_error("unexpected identifier type"); } + void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/) + { + throw std::runtime_error("unexpected intrinsic function type"); + } + void LangWriter::Append(const ShaderAst::MatrixType& matrixType) { if (matrixType.columnCount == matrixType.rowCount) @@ -167,6 +175,11 @@ namespace Nz Append("[", matrixType.type, "]"); } + void LangWriter::Append(const ShaderAst::MethodType& /*functionType*/) + { + throw std::runtime_error("unexpected method type"); + } + void LangWriter::Append(ShaderAst::PrimitiveType type) { switch (type) @@ -201,14 +214,14 @@ namespace Nz Append(structDesc->name); } + void LangWriter::Append(const ShaderAst::Type& /*type*/) + { + throw std::runtime_error("unexpected type?"); + } + void LangWriter::Append(const ShaderAst::UniformType& uniformType) { - Append("uniform["); - std::visit([&](auto&& arg) - { - Append(arg); - }, uniformType.containedType); - Append("]"); + Append("uniform[", uniformType.containedType, "]"); } void LangWriter::Append(const ShaderAst::VectorType& vecType) @@ -411,6 +424,10 @@ namespace Nz { switch (entry.layout.GetResultingValue()) { + case StructLayout::Packed: + Append("packed"); + break; + case StructLayout::Std140: Append("std140"); break; @@ -558,6 +575,18 @@ namespace Nz m_currentState->variableNames.emplace(varIndex, std::move(varName)); } + void LangWriter::ScopeVisit(ShaderAst::Statement& node) + { + if (node.GetType() != ShaderAst::NodeType::ScopedStatement) + { + EnterScope(); + node.Visit(*this); + LeaveScope(true); + } + else + node.Visit(*this); + } + void LangWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired) { bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue); @@ -590,12 +619,19 @@ namespace Nz assert(!IsStructType(exprType)); // Array access + Append("["); + + bool first = true; for (ShaderAst::ExpressionPtr& expr : node.indices) { - Append("["); + if (!first) + Append(", "); + expr->Visit(*this); - Append("]"); + first = false; } + + Append("]"); } void LangWriter::Visit(ShaderAst::AssignExpression& node) @@ -628,9 +664,7 @@ namespace Nz statement.condition->Visit(*this); AppendLine(")"); - EnterScope(); - statement.statement->Visit(*this); - LeaveScope(); + ScopeVisit(*statement.statement); first = false; } @@ -639,9 +673,7 @@ namespace Nz { AppendLine("else"); - EnterScope(); - node.elseStatement->Visit(*this); - LeaveScope(); + ScopeVisit(*node.elseStatement); } } @@ -800,8 +832,12 @@ namespace Nz RegisterVariable(varIndex++, node.parameters[i].name); } Append(")"); - if (!IsNoType(node.returnType)) - Append(" -> ", node.returnType); + if (node.returnType.HasValue()) + { + const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue(); + if (!IsNoType(returnType)) + Append(" -> ", returnType); + } AppendLine(); EnterScope(); @@ -896,9 +932,7 @@ namespace Nz AppendLine(); - EnterScope(); - node.statement->Visit(*this); - LeaveScope(); + ScopeVisit(*node.statement); } void LangWriter::Visit(ShaderAst::ForEachStatement& node) @@ -911,9 +945,7 @@ namespace Nz node.expression->Visit(*this); AppendLine(); - EnterScope(); - node.statement->Visit(*this); - LeaveScope(); + ScopeVisit(*node.statement); } void LangWriter::Visit(ShaderAst::IntrinsicExpression& node) @@ -1001,6 +1033,13 @@ namespace Nz Append("return;"); } + void LangWriter::Visit(ShaderAst::ScopedStatement& node) + { + EnterScope(); + node.statement->Visit(*this); + LeaveScope(true); + } + void LangWriter::Visit(ShaderAst::SwizzleExpression& node) { Visit(node.expression, true); @@ -1043,9 +1082,7 @@ namespace Nz node.condition->Visit(*this); AppendLine(")"); - EnterScope(); - node.body->Visit(*this); - LeaveScope(); + ScopeVisit(*node.body); } void LangWriter::AppendHeader() diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index d73865aa0..c940145d9 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -20,13 +20,6 @@ namespace Nz::ShaderLang { "unchanged", ShaderAst::DepthWriteMode::Unchanged }, }; - std::unordered_map s_identifierToBasicType = { - { "bool", ShaderAst::PrimitiveType::Boolean }, - { "i32", ShaderAst::PrimitiveType::Int32 }, - { "f32", ShaderAst::PrimitiveType::Float32 }, - { "u32", ShaderAst::PrimitiveType::UInt32 } - }; - std::unordered_map s_identifierToAttributeType = { { "binding", ShaderAst::AttributeType::Binding }, { "builtin", ShaderAst::AttributeType::Builtin }, @@ -71,7 +64,7 @@ namespace Nz::ShaderLang } template - void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::AttributeValue& targetAttribute, ShaderAst::Attribute::Param&& param, bool requireValue = true) + void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::ExpressionValue& targetAttribute, ShaderAst::ExprValue::Param&& param, bool requireValue = true) { if (targetAttribute.HasValue()) throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" }; @@ -83,7 +76,7 @@ namespace Nz::ShaderLang } template - void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map& map, ShaderAst::AttributeValue& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional defaultValue = {}) + void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map& map, ShaderAst::ExpressionValue& targetAttribute, ShaderAst::ExprValue::Param&& param, std::optional defaultValue = {}) { if (targetAttribute.HasValue()) throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" }; @@ -123,9 +116,7 @@ namespace Nz::ShaderLang m_context = &context; - std::vector attributes; - - EnterScope(); + std::vector attributes; bool reachedEndOfStream = false; while (!reachedEndOfStream) @@ -179,8 +170,6 @@ namespace Nz::ShaderLang } } - LeaveScope(); - return std::move(context.root); } @@ -198,161 +187,6 @@ namespace Nz::ShaderLang m_context->tokenIndex += count; } - std::optional Parser::DecodeType(const std::string& identifier) - { - if (auto it = s_identifierToBasicType.find(identifier); it != s_identifierToBasicType.end()) - { - Consume(); - return it->second; - } - - //FIXME: Handle this better - if (identifier == "array") - { - Consume(); - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - - ShaderAst::ArrayType arrayType; - arrayType.containedType = std::make_unique(); - arrayType.containedType->type = ParseType(); - - Expect(Advance(), TokenType::Comma); //< , - - arrayType.length = ParseExpression(); - - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return arrayType; - } - else if (identifier == "mat4") - { - Consume(); - - ShaderAst::MatrixType matrixType; - matrixType.columnCount = 4; - matrixType.rowCount = 4; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - matrixType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return matrixType; - } - else if (identifier == "mat3") - { - Consume(); - - ShaderAst::MatrixType matrixType; - matrixType.columnCount = 3; - matrixType.rowCount = 3; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - matrixType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return matrixType; - } - else if (identifier == "mat2") - { - Consume(); - - ShaderAst::MatrixType matrixType; - matrixType.columnCount = 2; - matrixType.rowCount = 2; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - matrixType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return matrixType; - } - else if (identifier == "sampler2D") - { - Consume(); - - ShaderAst::SamplerType samplerType; - samplerType.dim = ImageType::E2D; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - samplerType.sampledType = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return samplerType; - } - else if (identifier == "samplerCube") - { - Consume(); - - ShaderAst::SamplerType samplerType; - samplerType.dim = ImageType::Cubemap; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - samplerType.sampledType = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return samplerType; - } - else if (identifier == "uniform") - { - Consume(); - - ShaderAst::UniformType uniformType; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() }; - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return uniformType; - } - else if (identifier == "vec2") - { - Consume(); - - ShaderAst::VectorType vectorType; - vectorType.componentCount = 2; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - vectorType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return vectorType; - } - else if (identifier == "vec3") - { - Consume(); - - ShaderAst::VectorType vectorType; - vectorType.componentCount = 3; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - vectorType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return vectorType; - } - else if (identifier == "vec4") - { - Consume(); - - ShaderAst::VectorType vectorType; - vectorType.componentCount = 4; - - Expect(Advance(), TokenType::OpenSquareBracket); //< [ - vectorType.type = ParsePrimitiveType(); - Expect(Advance(), TokenType::ClosingSquareBracket); //< ] - - return vectorType; - } - else - return std::nullopt; - } - - void Parser::EnterScope() - { - m_context->scopeSizes.push_back(m_context->identifiersInScope.size()); - } - const Token& Parser::Expect(const Token& token, TokenType type) { if (token.type != type) @@ -377,33 +211,15 @@ namespace Nz::ShaderLang return token; } - void Parser::LeaveScope() - { - assert(!m_context->scopeSizes.empty()); - m_context->identifiersInScope.resize(m_context->scopeSizes.back()); - m_context->scopeSizes.pop_back(); - } - - bool Parser::IsVariableInScope(const std::string_view& identifier) const - { - return std::find(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), identifier) != m_context->identifiersInScope.rend(); - } - - void Parser::RegisterVariable(std::string identifier) - { - assert(!m_context->scopeSizes.empty()); - m_context->identifiersInScope.push_back(std::move(identifier)); - } - const Token& Parser::Peek(std::size_t advance) { assert(m_context->tokenIndex + advance < m_context->tokenCount); return m_context->tokens[m_context->tokenIndex + advance]; } - std::vector Parser::ParseAttributes() + std::vector Parser::ParseAttributes() { - std::vector attributes; + std::vector attributes; Expect(Advance(), TokenType::OpenSquareBracket); @@ -431,7 +247,7 @@ namespace Nz::ShaderLang ShaderAst::AttributeType attributeType = ParseIdentifierAsAttributeType(); - ShaderAst::Attribute::Param arg; + ShaderAst::ExprValue::Param arg; if (Peek().type == TokenType::OpenParenthesis) { Consume(); @@ -454,7 +270,7 @@ namespace Nz::ShaderLang return attributes; } - void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue) + void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue& type, ShaderAst::ExpressionPtr& initialValue) { name = ParseIdentifierAsName(); @@ -464,10 +280,8 @@ namespace Nz::ShaderLang type = ParseType(); } - else - type = ShaderAst::NoType{}; - if (IsNoType(type) || Peek().type == TokenType::Assign) + if (!type.HasValue() || Peek().type == TokenType::Assign) { Expect(Advance(), TokenType::Assign); initialValue = ParseExpression(); @@ -522,11 +336,10 @@ namespace Nz::ShaderLang case TokenType::Identifier: { std::string constName; - ShaderAst::ExpressionType constType; + ShaderAst::ExpressionValue constType; ShaderAst::ExpressionPtr initialValue; ParseVariableDeclaration(constName, constType, initialValue); - RegisterVariable(constName); return ShaderBuilder::DeclareConst(std::move(constName), std::move(constType), std::move(initialValue)); } @@ -552,14 +365,14 @@ namespace Nz::ShaderLang return ShaderBuilder::Discard(); } - ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector attributes) + ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector attributes) { Expect(Advance(), TokenType::External); Expect(Advance(), TokenType::OpenCurlyBracket); std::unique_ptr externalStatement = std::make_unique(); - ShaderAst::AttributeValue condition; + ShaderAst::ExpressionValue condition; for (auto&& [attributeType, arg] : attributes) { @@ -624,8 +437,6 @@ namespace Nz::ShaderLang extVar.name = ParseIdentifierAsName(); Expect(Advance(), TokenType::Colon); extVar.type = ParseType(); - - RegisterVariable(extVar.name); } Expect(Advance(), TokenType::ClosingCurlyBracket); @@ -636,7 +447,7 @@ namespace Nz::ShaderLang return externalStatement; } - ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector attributes) + ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector attributes) { Expect(Advance(), TokenType::For); @@ -710,7 +521,7 @@ namespace Nz::ShaderLang return ParseStatementList(); } - ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector attributes) + ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector attributes) { Expect(Advance(), TokenType::FunctionDeclaration); @@ -738,24 +549,18 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::ClosingParenthesis); - ShaderAst::ExpressionType returnType; + ShaderAst::ExpressionValue returnType; if (Peek().type == TokenType::Arrow) { Consume(); returnType = ParseType(); } - EnterScope(); - for (const auto& parameter : parameters) - RegisterVariable(parameter.name); - std::vector functionBody = ParseFunctionBody(); - LeaveScope(); - auto func = ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); - ShaderAst::AttributeValue condition; + ShaderAst::ExpressionValue condition; for (auto&& [attributeType, arg] : attributes) { @@ -794,7 +599,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Colon); - ShaderAst::ExpressionType parameterType = ParseType(); + ShaderAst::ExpressionPtr parameterType = ParseType(); return { parameterName, std::move(parameterType) }; } @@ -807,7 +612,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Colon); - ShaderAst::ExpressionType optionType = ParseType(); + ShaderAst::ExpressionPtr optionType = ParseType(); ShaderAst::ExpressionPtr initialValue; if (Peek().type == TokenType::Assign) @@ -837,7 +642,7 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr Parser::ParseSingleStatement() { - std::vector attributes; + std::vector attributes; ShaderAst::StatementPtr statement; do { @@ -912,15 +717,13 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr Parser::ParseStatement() { if (Peek().type == TokenType::OpenCurlyBracket) - return ShaderBuilder::MultiStatement(ParseStatementList()); + return ShaderBuilder::Scoped(ShaderBuilder::MultiStatement(ParseStatementList())); else return ParseSingleStatement(); } std::vector Parser::ParseStatementList() { - EnterScope(); - Expect(Advance(), TokenType::OpenCurlyBracket); std::vector statements; @@ -931,19 +734,17 @@ namespace Nz::ShaderLang } Consume(); //< Consume closing curly bracket - LeaveScope(); - return statements; } - ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector attributes) + ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector attributes) { Expect(Advance(), TokenType::Struct); ShaderAst::StructDescription description; description.name = ParseIdentifierAsName(); - ShaderAst::AttributeValue condition; + ShaderAst::ExpressionValue condition; for (auto&& [attributeType, attributeParam] : attributes) { @@ -1065,16 +866,15 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Let); std::string variableName; - ShaderAst::ExpressionType variableType; + ShaderAst::ExpressionValue variableType; ShaderAst::ExpressionPtr expression; ParseVariableDeclaration(variableName, variableType, expression); - RegisterVariable(variableName); return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression)); } - ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector attributes) + ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector attributes) { Expect(Advance(), TokenType::While); @@ -1133,19 +933,6 @@ namespace Nz::ShaderLang accessMemberNode->identifiers.push_back(ParseIdentifierAsName()); } while (Peek().type == TokenType::Dot); - // FIXME - if (!accessMemberNode->identifiers.empty() && accessMemberNode->identifiers.front() == "Sample") - { - if (Peek().type == TokenType::OpenParenthesis) - { - auto parameters = ParseParameters(); - parameters.insert(parameters.begin(), std::move(accessMemberNode->expr)); - - lhs = ShaderBuilder::Intrinsic(ShaderAst::IntrinsicType::SampleTexture, std::move(parameters)); - break; - } - } - lhs = std::move(accessMemberNode); } else @@ -1160,10 +947,10 @@ namespace Nz::ShaderLang Consume(); indexNode->indices.push_back(ParseExpression()); - - Expect(Advance(), TokenType::ClosingSquareBracket); } - while (Peek().type == TokenType::OpenSquareBracket); + while (Peek().type == TokenType::Comma); + + Expect(Advance(), TokenType::ClosingSquareBracket); lhs = std::move(indexNode); } @@ -1171,6 +958,15 @@ namespace Nz::ShaderLang currentTokenType = Peek().type; } + if (currentTokenType == TokenType::OpenParenthesis) + { + // Function call + auto parameters = ParseParameters(); + lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters)); + + c = true; + } + if (c) continue; @@ -1302,23 +1098,7 @@ namespace Nz::ShaderLang return ParseFloatingPointExpression(); case TokenType::Identifier: - { - const std::string& identifier = std::get(token.data); - - // Is it a cast? - std::optional exprType = DecodeType(identifier); - if (exprType) - return ShaderBuilder::Cast(std::move(*exprType), ParseParameters()); - - if (Peek(1).type == TokenType::OpenParenthesis) - { - // Function call - Consume(); - return ShaderBuilder::CallFunction(identifier, ParseParameters()); - } - else - return ParseIdentifier(); - } + return ParseIdentifier(); case TokenType::IntegerValue: return ParseIntegerExpression(); @@ -1370,28 +1150,10 @@ namespace Nz::ShaderLang const std::string& Parser::ParseIdentifierAsName() { const Token& identifierToken = Expect(Advance(), TokenType::Identifier); - const std::string& identifier = std::get(identifierToken.data); - - auto it = s_identifierToBasicType.find(identifier); - if (it != s_identifierToBasicType.end()) - throw ReservedKeyword{}; - - return identifier; + return std::get(identifierToken.data); } - ShaderAst::PrimitiveType Parser::ParsePrimitiveType() - { - const Token& identifierToken = Expect(Advance(), TokenType::Identifier); - const std::string& identifier = std::get(identifierToken.data); - - auto it = s_identifierToBasicType.find(identifier); - if (it == s_identifierToBasicType.end()) - throw UnknownType{}; - - return it->second; - } - - ShaderAst::ExpressionType Parser::ParseType() + ShaderAst::ExpressionPtr Parser::ParseType() { // Handle () as no type if (Peek().type == TokenType::OpenParenthesis) @@ -1399,20 +1161,10 @@ namespace Nz::ShaderLang Consume(); Expect(Advance(), TokenType::ClosingParenthesis); - return ShaderAst::NoType{}; + return ShaderBuilder::Constant(ShaderAst::NoValue{}); } - const Token& identifierToken = Expect(Peek(), TokenType::Identifier); - const std::string& identifier = std::get(identifierToken.data); - - auto type = DecodeType(identifier); - if (!type) - { - Consume(); - return ShaderAst::IdentifierType{ identifier }; - } - - return *std::move(type); + return ParseExpression(); } int Parser::GetTokenPrecedence(TokenType token) @@ -1433,6 +1185,7 @@ namespace Nz::ShaderLang case TokenType::NotEqual: return 50; case TokenType::Plus: return 60; case TokenType::OpenSquareBracket: return 100; + case TokenType::OpenParenthesis: return 100; default: return -1; } } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 224b223d7..381b3c555 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -400,8 +400,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node) { - assert(std::holds_alternative(node.targetFunction)); - std::size_t functionIndex = std::get(node.targetFunction); + std::size_t functionIndex = std::get(GetExpressionType(*node.targetFunction)).funcIndex; UInt32 funcId = 0; for (const auto& [funcIndex, func] : m_funcData) @@ -443,7 +442,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node) { - const ShaderAst::ExpressionType& targetExprType = node.targetType; + const ShaderAst::ExpressionType& targetExprType = node.targetType.GetResultingValue(); if (IsPrimitiveType(targetExprType)) { ShaderAst::PrimitiveType targetType = std::get(targetExprType); @@ -584,7 +583,7 @@ namespace Nz std::size_t varIndex = *node.varIndex; for (auto&& extVar : node.externalVars) - RegisterExternalVariable(varIndex++, extVar.type); + RegisterExternalVariable(varIndex++, extVar.type.GetResultingValue()); } void SpirvAstVisitor::Visit(ShaderAst::DeclareFunctionStatement& node) @@ -674,7 +673,7 @@ namespace Nz { const auto& func = m_funcData[m_funcIndex]; - UInt32 typeId = m_writer.GetTypeId(node.varType); + UInt32 typeId = m_writer.GetTypeId(node.varType.GetResultingValue()); assert(node.varIndex); auto varIt = func.varIndexToVarId.find(*node.varIndex); @@ -932,6 +931,11 @@ namespace Nz m_currentBlock->Append(SpirvOp::OpReturn); } + void SpirvAstVisitor::Visit(ShaderAst::ScopedStatement& node) + { + node.statement->Visit(*this); + } + void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 2d65f26e1..d7e94a5c1 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -513,7 +513,7 @@ namespace Nz for (const Structure::Member& member : structData.members) { - member.offset = std::visit([&](auto&& arg) -> std::size_t + member.offset = SafeCast(std::visit([&](auto&& arg) -> std::size_t { using T = std::decay_t; @@ -601,7 +601,7 @@ namespace Nz throw std::runtime_error("unexpected void as struct member"); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, member.type->type); + }, member.type->type)); } return structOffsets; @@ -671,7 +671,7 @@ namespace Nz return std::make_shared(Array{ builtContainedType, - BuildConstant(type.length.GetResultingValue()), + BuildConstant(type.length), arrayStride }); } @@ -802,7 +802,7 @@ namespace Nz auto& sMembers = sType.members.emplace_back(); sMembers.name = member.name; - sMembers.type = BuildType(member.type); + sMembers.type = BuildType(member.type.GetResultingValue()); } m_internal->isInBlockStruct = wasInBlock; @@ -817,8 +817,7 @@ namespace Nz auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr { - assert(std::holds_alternative(type.containedType)); - return BuildType(std::get(type.containedType)); + return BuildType(type.containedType); } UInt32 SpirvConstantCache::GetId(const Constant& c) @@ -918,12 +917,12 @@ namespace Nz return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength); } - std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Function& type, std::size_t arrayLength) const + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Function& /*type*/, std::size_t /*arrayLength*/) const { throw std::runtime_error("unexpected Function"); } - std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Image& type, std::size_t arrayLength) const + std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Image& /*type*/, std::size_t /*arrayLength*/) const { throw std::runtime_error("unexpected Image"); } diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index c43ef8653..e234d64bd 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -30,9 +30,46 @@ namespace Nz return resultId; }, + [this](const PointerChainAccess& pointerChainAccess) -> UInt32 + { + UInt32 pointerType = m_writer.RegisterPointerType(*pointerChainAccess.exprType, pointerChainAccess.storage); //< FIXME: We shouldn't register this so late + + UInt32 pointerId = m_visitor.AllocateResultId(); + + m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) + { + appender(pointerType); + appender(pointerId); + appender(pointerChainAccess.pointerId); + + for (UInt32 id : pointerChainAccess.indices) + appender(id); + }); + + UInt32 resultId = m_visitor.AllocateResultId(); + m_block.Append(SpirvOp::OpLoad, m_writer.GetTypeId(*pointerChainAccess.exprType), resultId, pointerId); + + return resultId; + }, [](const Value& value) -> UInt32 { - return value.resultId; + return value.valueId; + }, + [this](const ValueExtraction& extractedValue) -> UInt32 + { + UInt32 resultId = m_visitor.AllocateResultId(); + + m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender) + { + appender(extractedValue.typeId); + appender(resultId); + appender(extractedValue.valueId); + + for (UInt32 id : extractedValue.indices) + appender(id); + }); + + return resultId; }, [](std::monostate) -> UInt32 { @@ -47,48 +84,42 @@ namespace Nz const ShaderAst::ExpressionType& exprType = GetExpressionType(node); - UInt32 resultId = m_visitor.AllocateResultId(); UInt32 typeId = m_writer.GetTypeId(exprType); + assert(node.indices.size() == 1); + UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front()); + std::visit(overloaded { [&](const Pointer& pointer) { - UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME + PointerChainAccess pointerChainAccess; + pointerChainAccess.exprType = &exprType; + pointerChainAccess.indices = { indexId }; + pointerChainAccess.pointedTypeId = pointer.pointedTypeId; + pointerChainAccess.pointerId = pointer.pointerId; + pointerChainAccess.storage = pointer.storage; - StackArray indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size()); - for (std::size_t i = 0; i < node.indices.size(); ++i) - indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]); - - m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) - { - appender(pointerType); - appender(resultId); - appender(pointer.pointerId); - - for (UInt32 id : indexIds) - appender(id); - }); - - m_value = Pointer { pointer.storage, resultId, typeId }; + m_value = std::move(pointerChainAccess); + }, + [&](PointerChainAccess& pointerChainAccess) + { + pointerChainAccess.exprType = &exprType; + pointerChainAccess.indices.push_back(indexId); }, [&](const Value& value) { - StackArray indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size()); - for (std::size_t i = 0; i < node.indices.size(); ++i) - indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]); + ValueExtraction extractedValue; + extractedValue.indices = { indexId }; + extractedValue.typeId = typeId; + extractedValue.valueId = value.valueId; - m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender) - { - appender(typeId); - appender(resultId); - appender(value.resultId); - - for (UInt32 id : indexIds) - appender(id); - }); - - m_value = Value { resultId }; + m_value = std::move(extractedValue); + }, + [&](ValueExtraction& extractedValue) + { + extractedValue.indices.push_back(indexId); + extractedValue.typeId = typeId; }, [](std::monostate) { diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp index d5f749e7c..571b01ef4 100644 --- a/src/Nazara/Shader/SpirvExpressionStore.cpp +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -100,19 +100,10 @@ namespace Nz UInt32 resultId = m_visitor.AllocateResultId(); UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME - StackArray indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size()); - for (std::size_t i = 0; i < node.indices.size(); ++i) - indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]); + assert(node.indices.size() == 1); + UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front()); - m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) - { - appender(pointerType); - appender(resultId); - appender(pointer.pointerId); - - for (UInt32 id : indexIds) - appender(id); - }); + m_block.Append(SpirvOp::OpAccessChain, pointerType, resultId, pointer.pointerId, indexId); m_value = Pointer { pointer.storage, resultId }; }, @@ -147,6 +138,8 @@ namespace Nz { // Swizzle the swizzle, keep common components std::array newIndices; + newIndices.fill(0); //< keep compiler happy + for (std::size_t i = 0; i < node.componentCount; ++i) { assert(node.components[i] < swizzledPointer.componentCount); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index a4ecc52d5..a1414e254 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -144,17 +144,18 @@ namespace Nz SpirvConstantCache::Variable variable; variable.debugName = extVar.name; - if (ShaderAst::IsSamplerType(extVar.type)) + const ShaderAst::ExpressionType& extVarType = extVar.type.GetResultingValue(); + + if (ShaderAst::IsSamplerType(extVarType)) { variable.storageClass = SpirvStorageClass::UniformConstant; - variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass); + variable.type = m_constantCache.BuildPointerType(extVarType, variable.storageClass); } else { - assert(ShaderAst::IsUniformType(extVar.type)); - const auto& uniformType = std::get(extVar.type); - assert(std::holds_alternative(uniformType.containedType)); - const auto& structType = std::get(uniformType.containedType); + assert(ShaderAst::IsUniformType(extVarType)); + const auto& uniformType = std::get(extVarType); + const auto& structType = uniformType.containedType; assert(structType.structIndex < declaredStructs.size()); const auto& type = m_constantCache.BuildType(*declaredStructs[structType.structIndex], { SpirvDecoration::Block }); @@ -188,16 +189,27 @@ namespace Nz { std::vector parameterTypes; for (auto& parameter : node.parameters) - parameterTypes.push_back(parameter.type); + parameterTypes.push_back(parameter.type.GetResultingValue()); - funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType)); - funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes)); + if (node.returnType.HasValue()) + { + const auto& returnType = node.returnType.GetResultingValue(); + funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(returnType)); + funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(returnType, parameterTypes)); + } + else + { + funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{})); + funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes)); + } for (auto& parameter : node.parameters) { + const auto& parameterType = parameter.type.GetResultingValue(); + auto& funcParam = funcData.parameters.emplace_back(); - funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)); - funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type)); + funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function)); + funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameterType)); } } else @@ -235,9 +247,11 @@ namespace Nz { assert(node.parameters.size() == 1); auto& parameter = node.parameters.front(); - assert(std::holds_alternative(parameter.type)); + const auto& parameterType = parameter.type.GetResultingValue(); - std::size_t structIndex = std::get(parameter.type).structIndex; + assert(std::holds_alternative(parameterType)); + + std::size_t structIndex = std::get(parameterType).structIndex; const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex]; std::size_t memberIndex = 0; @@ -250,7 +264,7 @@ namespace Nz { inputs.push_back({ m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))), - m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)), + m_constantCache.Register(*m_constantCache.BuildPointerType(member.type.GetResultingValue(), SpirvStorageClass::Function)), varId }); } @@ -259,18 +273,20 @@ namespace Nz } inputStruct = EntryPoint::InputStruct{ - m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)), - m_constantCache.Register(*m_constantCache.BuildType(parameter.type)) + m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function)), + m_constantCache.Register(*m_constantCache.BuildType(parameter.type.GetResultingValue())) }; } std::optional outputStructId; std::vector outputs; - if (!IsNoType(node.returnType)) + if (node.returnType.HasValue()) { - assert(std::holds_alternative(node.returnType)); + const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue(); - std::size_t structIndex = std::get(node.returnType).structIndex; + assert(std::holds_alternative(returnType)); + + std::size_t structIndex = std::get(returnType).structIndex; const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex]; std::size_t memberIndex = 0; @@ -283,7 +299,7 @@ namespace Nz { outputs.push_back({ Int32(memberIndex), - m_constantCache.Register(*m_constantCache.BuildType(member.type)), + m_constantCache.Register(*m_constantCache.BuildType(member.type.GetResultingValue())), varId }); } @@ -291,7 +307,7 @@ namespace Nz memberIndex++; } - outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType)); + outputStructId = m_constantCache.Register(*m_constantCache.BuildType(returnType)); } funcData.entryPointData = EntryPoint{ @@ -334,7 +350,7 @@ namespace Nz func.varIndexToVarId[*node.varIndex] = func.variables.size(); auto& var = func.variables.emplace_back(); - var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function)); + var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType.GetResultingValue(), SpirvStorageClass::Function)); } void Visit(ShaderAst::IdentifierExpression& node) override @@ -408,7 +424,7 @@ namespace Nz variable.debugName = builtin.debugName; variable.funcId = funcIndex; variable.storageClass = storageClass; - variable.type = m_constantCache.BuildPointerType(member.type, storageClass); + variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass); UInt32 varId = m_constantCache.Register(variable); builtinDecorations[varId] = builtinDecoration; @@ -421,7 +437,7 @@ namespace Nz variable.debugName = member.name; variable.funcId = funcIndex; variable.storageClass = storageClass; - variable.type = m_constantCache.BuildPointerType(member.type, storageClass); + variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass); UInt32 varId = m_constantCache.Register(variable); locationDecorations[varId] = member.locationIndex.GetResultingValue(); @@ -643,9 +659,12 @@ namespace Nz parameterTypes.reserve(functionNode.parameters.size()); for (const auto& parameter : functionNode.parameters) - parameterTypes.push_back(parameter.type); + parameterTypes.push_back(parameter.type.GetResultingValue()); - return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes); + if (functionNode.returnType.HasValue()) + return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType.GetResultingValue(), parameterTypes); + else + return m_currentState->constantTypeCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes); } UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const diff --git a/tests/Engine/Shader/AccessMemberTest.cpp b/tests/Engine/Shader/AccessMemberTest.cpp index ac389c868..dbec148ec 100644 --- a/tests/Engine/Shader/AccessMemberTest.cpp +++ b/tests/Engine/Shader/AccessMemberTest.cpp @@ -38,7 +38,7 @@ external auto secondAccess = Nz::ShaderBuilder::AccessMember(std::move(firstAccess), { "field" }); auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(secondAccess), { 2u }); - auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle)); + auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Float32 }, std::move(swizzle)); multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); @@ -75,7 +75,7 @@ OpFunctionEnd)"); auto access = Nz::ShaderBuilder::AccessMember(std::move(ubo), { "s", "field" }); auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(access), { 2u }); - auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle)); + auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Float32 }, std::move(swizzle)); multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); diff --git a/tests/Engine/Shader/Loops.cpp b/tests/Engine/Shader/Loops.cpp index dbab1022f..2f9c019fe 100644 --- a/tests/Engine/Shader/Loops.cpp +++ b/tests/Engine/Shader/Loops.cpp @@ -315,7 +315,6 @@ OpULessThan OpLoopMerge OpBranchConditional OpLabel -OpAccessChain OpLoad OpAccessChain OpLoad diff --git a/tests/Engine/Shader/Optimizations.cpp b/tests/Engine/Shader/Optimizations.cpp index edf34e8d5..2ed7446ba 100644 --- a/tests/Engine/Shader/Optimizations.cpp +++ b/tests/Engine/Shader/Optimizations.cpp @@ -33,6 +33,7 @@ fn main() fn main() { let output: f32 = 42.000000; +} )"); } diff --git a/tests/Engine/Shader/Swizzle.cpp b/tests/Engine/Shader/Swizzle.cpp index 9ffeac2ab..6922e9424 100644 --- a/tests/Engine/Shader/Swizzle.cpp +++ b/tests/Engine/Shader/Swizzle.cpp @@ -150,8 +150,8 @@ OpFunctionEnd)"); [entry(frag)] fn main() { - let vec = max(2.0, 1.0).xxx; - let vec2 = min(2.0, 1.0).xxx; + let v = max(2.0, 1.0).xxx; + let v2 = min(2.0, 1.0).xxx; } )"; @@ -161,9 +161,9 @@ fn main() void main() { float cachedResult = max(2.000000, 1.000000); - vec3 vec = vec3(cachedResult, cachedResult, cachedResult); + vec3 v = vec3(cachedResult, cachedResult, cachedResult); float cachedResult_2 = min(2.000000, 1.000000); - vec3 vec2_ = vec3(cachedResult_2, cachedResult_2, cachedResult_2); + vec3 v2 = vec3(cachedResult_2, cachedResult_2, cachedResult_2); } )"); @@ -171,8 +171,8 @@ void main() [entry(frag)] fn main() { - let vec: vec3[f32] = (max(2.000000, 1.000000)).xxx; - let vec2: vec3[f32] = (min(2.000000, 1.000000)).xxx; + let v: vec3[f32] = (max(2.000000, 1.000000)).xxx; + let v2: vec3[f32] = (min(2.000000, 1.000000)).xxx; } )");