From fed7370e7733ff5ee9856f4d0541735798ea9ff9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Wed, 10 Mar 2021 11:18:13 +0100 Subject: [PATCH] Rework shader AST (WIP) --- include/Nazara/Renderer/RenderDevice.hpp | 4 +- include/Nazara/Shader.hpp | 17 +- include/Nazara/Shader/GlslWriter.hpp | 86 +-- include/Nazara/Shader/GlslWriter.inl | 120 ---- include/Nazara/Shader/ShaderAst.hpp | 130 ---- include/Nazara/Shader/ShaderAst.inl | 128 ---- include/Nazara/Shader/ShaderAstCache.hpp | 48 ++ include/Nazara/Shader/ShaderAstCache.inl | 37 + include/Nazara/Shader/ShaderAstCloner.hpp | 90 ++- .../Nazara/Shader/ShaderAstExpressionType.hpp | 57 ++ .../Nazara/Shader/ShaderAstExpressionType.inl | 17 + .../Shader/ShaderAstExpressionVisitor.hpp | 32 + .../ShaderAstExpressionVisitorExcept.hpp | 26 + include/Nazara/Shader/ShaderAstNodes.hpp | 53 ++ include/Nazara/Shader/ShaderAstOptimizer.hpp | 35 +- .../Shader/ShaderAstRecursiveVisitor.hpp | 49 +- include/Nazara/Shader/ShaderAstSerializer.hpp | 84 ++- include/Nazara/Shader/ShaderAstSerializer.inl | 42 +- .../Shader/ShaderAstStatementVisitor.hpp | 32 + .../ShaderAstStatementVisitorExcept.hpp | 26 + ...rExpressionType.hpp => ShaderAstTypes.hpp} | 25 +- ...rExpressionType.inl => ShaderAstTypes.inl} | 12 +- include/Nazara/Shader/ShaderAstUtils.hpp | 36 +- include/Nazara/Shader/ShaderAstUtils.inl | 4 +- include/Nazara/Shader/ShaderAstValidator.hpp | 78 +-- include/Nazara/Shader/ShaderAstValidator.inl | 6 +- include/Nazara/Shader/ShaderAstVisitor.hpp | 49 -- .../Nazara/Shader/ShaderAstVisitorExcept.hpp | 41 -- include/Nazara/Shader/ShaderBuilder.hpp | 100 ++- include/Nazara/Shader/ShaderBuilder.inl | 116 ++- include/Nazara/Shader/ShaderEnums.hpp | 39 +- include/Nazara/Shader/ShaderEnums.inl | 57 ++ include/Nazara/Shader/ShaderLangParser.hpp | 33 +- include/Nazara/Shader/ShaderLangTokenList.hpp | 4 +- include/Nazara/Shader/ShaderNodes.hpp | 505 ++++++-------- include/Nazara/Shader/ShaderNodes.inl | 388 +---------- include/Nazara/Shader/ShaderVarVisitor.hpp | 38 - .../Nazara/Shader/ShaderVarVisitorExcept.hpp | 28 - include/Nazara/Shader/ShaderVariables.hpp | 128 ---- include/Nazara/Shader/ShaderVariables.inl | 65 -- include/Nazara/Shader/ShaderWriter.hpp | 2 - include/Nazara/Shader/SpirvAstVisitor.hpp | 47 +- include/Nazara/Shader/SpirvConstantCache.hpp | 11 +- include/Nazara/Shader/SpirvExpressionLoad.hpp | 19 +- .../Nazara/Shader/SpirvExpressionStore.hpp | 20 +- include/Nazara/Shader/SpirvWriter.hpp | 31 +- include/Nazara/Shader/SpirvWriter.inl | 5 +- src/Nazara/Graphics/BasicMaterial.cpp | 4 +- src/Nazara/Graphics/UberShader.cpp | 10 +- src/Nazara/Shader/GlslWriter.cpp | 454 +++++------- src/Nazara/Shader/ShaderAst.cpp | 56 -- src/Nazara/Shader/ShaderAstCloner.cpp | 255 +++---- src/Nazara/Shader/ShaderAstExpressionType.cpp | 198 ++++++ ...tor.cpp => ShaderAstExpressionVisitor.cpp} | 11 +- .../ShaderAstExpressionVisitorExcept.cpp | 15 + src/Nazara/Shader/ShaderAstOptimizer.cpp | 166 ++--- .../Shader/ShaderAstRecursiveVisitor.cpp | 129 ++-- src/Nazara/Shader/ShaderAstSerializer.cpp | 659 +++++------------- ...itor.cpp => ShaderAstStatementVisitor.cpp} | 11 +- .../ShaderAstStatementVisitorExcept.cpp | 15 + src/Nazara/Shader/ShaderAstUtils.cpp | 60 +- src/Nazara/Shader/ShaderAstValidator.cpp | 543 +++++++-------- src/Nazara/Shader/ShaderAstVisitorExcept.cpp | 100 --- src/Nazara/Shader/ShaderLangLexer.cpp | 10 +- src/Nazara/Shader/ShaderLangParser.cpp | 138 ++-- src/Nazara/Shader/ShaderNodes.cpp | 268 +------ src/Nazara/Shader/ShaderVarVisitorExcept.cpp | 40 -- src/Nazara/Shader/ShaderVariables.cpp | 77 -- src/Nazara/Shader/SpirvAstVisitor.cpp | 475 ++++++------- src/Nazara/Shader/SpirvConstantCache.cpp | 64 +- src/Nazara/Shader/SpirvExpressionLoad.cpp | 43 +- src/Nazara/Shader/SpirvExpressionStore.cpp | 33 +- src/Nazara/Shader/SpirvWriter.cpp | 199 +++--- 73 files changed, 2721 insertions(+), 4312 deletions(-) delete mode 100644 include/Nazara/Shader/ShaderAst.hpp delete mode 100644 include/Nazara/Shader/ShaderAst.inl create mode 100644 include/Nazara/Shader/ShaderAstCache.hpp create mode 100644 include/Nazara/Shader/ShaderAstCache.inl create mode 100644 include/Nazara/Shader/ShaderAstExpressionType.hpp create mode 100644 include/Nazara/Shader/ShaderAstExpressionType.inl create mode 100644 include/Nazara/Shader/ShaderAstExpressionVisitor.hpp create mode 100644 include/Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp create mode 100644 include/Nazara/Shader/ShaderAstNodes.hpp create mode 100644 include/Nazara/Shader/ShaderAstStatementVisitor.hpp create mode 100644 include/Nazara/Shader/ShaderAstStatementVisitorExcept.hpp rename include/Nazara/Shader/{ShaderExpressionType.hpp => ShaderAstTypes.hpp} (56%) rename include/Nazara/Shader/{ShaderExpressionType.inl => ShaderAstTypes.inl} (89%) delete mode 100644 include/Nazara/Shader/ShaderAstVisitor.hpp delete mode 100644 include/Nazara/Shader/ShaderAstVisitorExcept.hpp create mode 100644 include/Nazara/Shader/ShaderEnums.inl delete mode 100644 include/Nazara/Shader/ShaderVarVisitor.hpp delete mode 100644 include/Nazara/Shader/ShaderVarVisitorExcept.hpp delete mode 100644 include/Nazara/Shader/ShaderVariables.hpp delete mode 100644 include/Nazara/Shader/ShaderVariables.inl delete mode 100644 src/Nazara/Shader/ShaderAst.cpp create mode 100644 src/Nazara/Shader/ShaderAstExpressionType.cpp rename src/Nazara/Shader/{ShaderAstVisitor.cpp => ShaderAstExpressionVisitor.cpp} (52%) create mode 100644 src/Nazara/Shader/ShaderAstExpressionVisitorExcept.cpp rename src/Nazara/Shader/{ShaderVarVisitor.cpp => ShaderAstStatementVisitor.cpp} (51%) create mode 100644 src/Nazara/Shader/ShaderAstStatementVisitorExcept.cpp delete mode 100644 src/Nazara/Shader/ShaderAstVisitorExcept.cpp delete mode 100644 src/Nazara/Shader/ShaderVarVisitorExcept.cpp delete mode 100644 src/Nazara/Shader/ShaderVariables.cpp diff --git a/include/Nazara/Renderer/RenderDevice.hpp b/include/Nazara/Renderer/RenderDevice.hpp index 8030c0f55..feb3a153c 100644 --- a/include/Nazara/Renderer/RenderDevice.hpp +++ b/include/Nazara/Renderer/RenderDevice.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -24,7 +25,6 @@ namespace Nz { class CommandPool; - class ShaderAst; class ShaderStage; class NAZARA_RENDERER_API RenderDevice @@ -39,7 +39,7 @@ namespace Nz virtual std::shared_ptr InstantiateRenderPass(std::vector attachments, std::vector subpassDescriptions, std::vector subpassDependencies) = 0; virtual std::shared_ptr InstantiateRenderPipeline(RenderPipelineInfo pipelineInfo) = 0; virtual std::shared_ptr InstantiateRenderPipelineLayout(RenderPipelineLayoutInfo pipelineLayoutInfo) = 0; - virtual std::shared_ptr InstantiateShaderStage(const ShaderAst& shaderAst, const ShaderWriter::States& states) = 0; + virtual std::shared_ptr InstantiateShaderStage(const ShaderAst::StatementPtr& shaderAst, const ShaderWriter::States& states) = 0; virtual std::shared_ptr InstantiateShaderStage(ShaderStageType type, ShaderLanguage lang, const void* source, std::size_t sourceSize) = 0; std::shared_ptr InstantiateShaderStage(ShaderStageType type, ShaderLanguage lang, const std::filesystem::path& sourcePath); virtual std::shared_ptr InstantiateTexture(const TextureInfo& params) = 0; diff --git a/include/Nazara/Shader.hpp b/include/Nazara/Shader.hpp index f8e41ff22..9c59a649d 100644 --- a/include/Nazara/Shader.hpp +++ b/include/Nazara/Shader.hpp @@ -32,23 +32,25 @@ #include #include #include -#include +#include #include +#include +#include +#include #include #include #include +#include +#include +#include #include #include -#include -#include #include #include #include -#include +#include +#include #include -#include -#include -#include #include #include #include @@ -58,6 +60,7 @@ #include #include #include +#include #include #endif // NAZARA_GLOBAL_SHADER_HPP diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 5a6f7049d..fd3c299ce 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -9,9 +9,7 @@ #include #include -#include -#include -#include +#include #include #include #include @@ -19,7 +17,7 @@ namespace Nz { - class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderVarVisitor, public ShaderAstVisitor + class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstRecursiveVisitor { public: struct Environment; @@ -30,7 +28,7 @@ namespace Nz GlslWriter(GlslWriter&&) = delete; ~GlslWriter() = default; - std::string Generate(const ShaderAst& shader, const States& conditions = {}); + std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {}); void SetEnv(Environment environment); @@ -46,67 +44,45 @@ namespace Nz static const char* GetFlipYUniformName(); private: - void Append(ShaderExpressionType type); - void Append(ShaderNodes::BuiltinEntry builtin); - void Append(ShaderNodes::BasicType type); - void Append(ShaderNodes::MemoryLayout layout); + void Append(ShaderAst::ShaderExpressionType type); + void Append(ShaderAst::BuiltinEntry builtin); + void Append(ShaderAst::BasicType type); + void Append(ShaderAst::MemoryLayout layout); template void Append(const T& param); void AppendCommentSection(const std::string& section); - void AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers); - void AppendFunction(const ShaderAst::Function& func); - void AppendFunctionPrototype(const ShaderAst::Function& func); + void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); void AppendLine(const std::string& txt = {}); - template void DeclareVariables(const ShaderAst& shader, const std::vector& variables, const std::string& keyword = {}, const std::string& section = {}); - void EnterScope(); void LeaveScope(); - using ShaderVarVisitor::Visit; - using ShaderAstVisitor::Visit; - void Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired = false); - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::BuiltinVariable& var) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::Discard& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::InputVariable& var) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::NoOp& node) override; - void Visit(ShaderNodes::ParameterVariable& var) override; - void Visit(ShaderNodes::ReturnStatement& node) override; - void Visit(ShaderNodes::OutputVariable& var) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; - void Visit(ShaderNodes::UniformVariable& var) override; + void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); - static bool HasExplicitBinding(const ShaderAst& shader); - static bool HasExplicitLocation(const ShaderAst& shader); + void Visit(ShaderAst::AccessMemberExpression& node) override; + void Visit(ShaderAst::AssignExpression& node) override; + void Visit(ShaderAst::BinaryExpression& node) override; + void Visit(ShaderAst::CastExpression& node) override; + void Visit(ShaderAst::ConditionalExpression& node) override; + void Visit(ShaderAst::ConstantExpression& node) override; + void Visit(ShaderAst::IdentifierExpression& node) override; + void Visit(ShaderAst::IntrinsicExpression& node) override; + void Visit(ShaderAst::SwizzleExpression& node) override; - struct Context - { - const ShaderAst* shader = nullptr; - const ShaderAst::Function* currentFunction = nullptr; - const States* states = nullptr; - }; + void Visit(ShaderAst::BranchStatement& node) override; + void Visit(ShaderAst::ConditionalStatement& node) override; + void Visit(ShaderAst::DeclareFunctionStatement& node) override; + void Visit(ShaderAst::DeclareVariableStatement& node) override; + void Visit(ShaderAst::DiscardStatement& node) override; + void Visit(ShaderAst::ExpressionStatement& node) override; + void Visit(ShaderAst::MultiStatement& node) override; + void Visit(ShaderAst::NoOpStatement& node) override; + void Visit(ShaderAst::ReturnStatement& node) override; - struct State - { - std::stringstream stream; - unsigned int indentLevel = 0; - }; + static bool HasExplicitBinding(ShaderAst::StatementPtr& shader); + static bool HasExplicitLocation(ShaderAst::StatementPtr& shader); + + struct State; - Context m_context; Environment m_environment; State* m_currentState; }; diff --git a/include/Nazara/Shader/GlslWriter.inl b/include/Nazara/Shader/GlslWriter.inl index fd5cbe094..1ecd13aee 100644 --- a/include/Nazara/Shader/GlslWriter.inl +++ b/include/Nazara/Shader/GlslWriter.inl @@ -3,130 +3,10 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include -#include #include namespace Nz { - template - void GlslWriter::Append(const T& param) - { - NazaraAssert(m_currentState, "This function should only be called while processing an AST"); - - m_currentState->stream << param; - } - - template - void GlslWriter::DeclareVariables(const ShaderAst& shader, const std::vector& variables, const std::string& keyword, const std::string& section) - { - if (!variables.empty()) - { - if (!section.empty()) - AppendCommentSection(section); - - for (const auto& var : variables) - { - if constexpr (std::is_same_v) - { - if (var.locationIndex) - { - Append("layout(location = "); - Append(*var.locationIndex); - Append(") "); - } - - if (!keyword.empty()) - { - Append(keyword); - Append(" "); - } - - Append(var.type); - Append(" "); - Append(var.name); - AppendLine(";"); - } - else if constexpr (std::is_same_v) - { - if (var.bindingIndex || var.memoryLayout) - { - Append("layout("); - - bool first = true; - if (var.bindingIndex) - { - if (!first) - Append(", "); - - Append("binding = "); - Append(*var.bindingIndex); - - first = false; - } - - if (var.memoryLayout) - { - if (!first) - Append(", "); - - Append(*var.memoryLayout); - - first = false; - } - - Append(") "); - } - - if (!keyword.empty()) - { - Append(keyword); - Append(" "); - } - - std::visit([&](auto&& arg) - { - using U = std::decay_t; - if constexpr (std::is_same_v) - { - Append(arg); - Append(" "); - Append(var.name); - } - else if constexpr (std::is_same_v) - { - const auto& structs = shader.GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); - if (it == structs.end()) - throw std::runtime_error("struct " + arg + " has not been defined"); - - const auto& s = *it; - - AppendLine(var.name + "_interface"); - AppendLine("{"); - for (const auto& m : s.members) - { - Append("\t"); - Append(m.type); - Append(" "); - Append(m.name); - AppendLine(";"); - } - Append("} "); - Append(var.name); - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - - }, var.type); - - AppendLine(";"); - AppendLine(); - } - } - - AppendLine(); - } - } } #include diff --git a/include/Nazara/Shader/ShaderAst.hpp b/include/Nazara/Shader/ShaderAst.hpp deleted file mode 100644 index f34b0ac9f..000000000 --- a/include/Nazara/Shader/ShaderAst.hpp +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Renderer module" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADER_AST_HPP -#define NAZARA_SHADER_AST_HPP - -#include -#include -#include -#include -#include -#include -#include - -namespace Nz -{ - class NAZARA_SHADER_API ShaderAst - { - public: - struct Condition; - struct Function; - struct FunctionParameter; - struct InputOutput; - struct Struct; - struct StructMember; - struct Uniform; - struct VariableBase; - - inline ShaderAst(ShaderStageType shaderStage); - ShaderAst(const ShaderAst&) = default; - ShaderAst(ShaderAst&&) noexcept = default; - ~ShaderAst() = default; - - void AddCondition(std::string name); - void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters = {}, ShaderExpressionType returnType = ShaderNodes::BasicType::Void); - void AddInput(std::string name, ShaderExpressionType type, std::optional locationIndex = {}); - void AddOutput(std::string name, ShaderExpressionType type, std::optional locationIndex = {}); - void AddStruct(std::string name, std::vector members); - void AddUniform(std::string name, ShaderExpressionType type, std::optional bindingIndex = {}, std::optional memoryLayout = {}); - - inline std::size_t FindConditionByName(const std::string_view& conditionName) const; - - inline const Condition& GetCondition(std::size_t i) const; - inline std::size_t GetConditionCount() const; - inline const std::vector& GetConditions() const; - inline const Function& GetFunction(std::size_t i) const; - inline std::size_t GetFunctionCount() const; - inline const std::vector& GetFunctions() const; - inline const InputOutput& GetInput(std::size_t i) const; - inline std::size_t GetInputCount() const; - inline const std::vector& GetInputs() const; - inline const InputOutput& GetOutput(std::size_t i) const; - inline std::size_t GetOutputCount() const; - inline const std::vector& GetOutputs() const; - inline ShaderStageType GetStage() const; - inline const Struct& GetStruct(std::size_t i) const; - inline std::size_t GetStructCount() const; - inline const std::vector& GetStructs() const; - inline const Uniform& GetUniform(std::size_t i) const; - inline std::size_t GetUniformCount() const; - inline const std::vector& GetUniforms() const; - - ShaderAst& operator=(const ShaderAst&) = default; - ShaderAst& operator=(ShaderAst&&) noexcept = default; - - struct Condition - { - std::string name; - }; - - struct VariableBase - { - std::string name; - ShaderExpressionType type; - }; - - struct FunctionParameter : VariableBase - { - }; - - struct Function - { - std::string name; - std::vector parameters; - ShaderExpressionType returnType; - ShaderNodes::StatementPtr statement; - }; - - struct InputOutput : VariableBase - { - std::optional locationIndex; - }; - - struct Uniform : VariableBase - { - std::optional bindingIndex; - std::optional memoryLayout; - }; - - struct Struct - { - std::string name; - std::vector members; - }; - - struct StructMember - { - std::string name; - ShaderExpressionType type; - }; - - static constexpr std::size_t InvalidCondition = std::numeric_limits::max(); - - private: - std::vector m_conditions; - std::vector m_functions; - std::vector m_inputs; - std::vector m_outputs; - std::vector m_structs; - std::vector m_uniforms; - ShaderStageType m_stage; - }; -} - -#include - -#endif // NAZARA_SHADER_AST_HPP diff --git a/include/Nazara/Shader/ShaderAst.inl b/include/Nazara/Shader/ShaderAst.inl deleted file mode 100644 index 64bce4583..000000000 --- a/include/Nazara/Shader/ShaderAst.inl +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz -{ - inline ShaderAst::ShaderAst(ShaderStageType shaderStage) : - m_stage(shaderStage) - { - } - - inline std::size_t ShaderAst::FindConditionByName(const std::string_view& conditionName) const - { - for (std::size_t i = 0; i < m_conditions.size(); ++i) - { - if (m_conditions[i].name == conditionName) - return i; - } - - return InvalidCondition; - } - - inline auto Nz::ShaderAst::GetCondition(std::size_t i) const -> const Condition& - { - assert(i < m_functions.size()); - return m_conditions[i]; - } - - inline std::size_t ShaderAst::GetConditionCount() const - { - return m_conditions.size(); - } - - inline auto ShaderAst::GetConditions() const -> const std::vector& - { - return m_conditions; - } - - inline auto ShaderAst::GetFunction(std::size_t i) const -> const Function& - { - assert(i < m_functions.size()); - return m_functions[i]; - } - - inline std::size_t ShaderAst::GetFunctionCount() const - { - return m_functions.size(); - } - - inline auto ShaderAst::GetFunctions() const -> const std::vector& - { - return m_functions; - } - - inline auto ShaderAst::GetInput(std::size_t i) const -> const InputOutput& - { - assert(i < m_inputs.size()); - return m_inputs[i]; - } - - inline std::size_t ShaderAst::GetInputCount() const - { - return m_inputs.size(); - } - - inline auto ShaderAst::GetInputs() const -> const std::vector& - { - return m_inputs; - } - - inline auto ShaderAst::GetOutput(std::size_t i) const -> const InputOutput& - { - assert(i < m_outputs.size()); - return m_outputs[i]; - } - - inline std::size_t ShaderAst::GetOutputCount() const - { - return m_outputs.size(); - } - - inline auto ShaderAst::GetOutputs() const -> const std::vector& - { - return m_outputs; - } - - inline ShaderStageType ShaderAst::GetStage() const - { - return m_stage; - } - - inline auto ShaderAst::GetStruct(std::size_t i) const -> const Struct& - { - assert(i < m_structs.size()); - return m_structs[i]; - } - - inline std::size_t ShaderAst::GetStructCount() const - { - return m_structs.size(); - } - - inline auto ShaderAst::GetStructs() const -> const std::vector& - { - return m_structs; - } - - inline auto ShaderAst::GetUniform(std::size_t i) const -> const Uniform& - { - assert(i < m_uniforms.size()); - return m_uniforms[i]; - } - - inline std::size_t ShaderAst::GetUniformCount() const - { - return m_uniforms.size(); - } - - inline auto ShaderAst::GetUniforms() const -> const std::vector& - { - return m_uniforms; - } -} - -#include diff --git a/include/Nazara/Shader/ShaderAstCache.hpp b/include/Nazara/Shader/ShaderAstCache.hpp new file mode 100644 index 000000000..595b6b410 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstCache.hpp @@ -0,0 +1,48 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTCACHE_HPP +#define NAZARA_SHADERASTCACHE_HPP + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + struct AstCache + { + struct Variable + { + ShaderExpressionType type; + }; + + struct Identifier + { + std::string name; + std::variant value; + }; + + struct Scope + { + std::optional parentScopeIndex; + std::vector identifiers; + }; + + inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const; + inline std::size_t GetScopeId(const Node* node) const; + + ShaderStageType stageType = ShaderStageType::Undefined; + std::unordered_map nodeExpressionType; + std::unordered_map scopeIdByNode; + std::vector scopes; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/ShaderAstCache.inl b/include/Nazara/Shader/ShaderAstCache.inl new file mode 100644 index 000000000..48cd769f9 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstCache.inl @@ -0,0 +1,37 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier* + { + assert(startingScopeId < scopes.size()); + + std::optional scopeId = startingScopeId; + do + { + const auto& scope = scopes[*scopeId]; + auto it = std::find_if(scope.identifiers.rbegin(), scope.identifiers.rend(), [&](const auto& identifier) { return identifier.name == identifierName; }); + if (it != scope.identifiers.rend()) + return &*it; + + scopeId = scope.parentScopeIndex; + } while (scopeId); + + return nullptr; + } + + inline std::size_t AstCache::GetScopeId(const Node* node) const + { + auto it = scopeIdByNode.find(node); + assert(it == scopeIdByNode.end()); + + return it->second; + } +} + +#include diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index b777e2ba6..23e69234a 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -9,70 +9,62 @@ #include #include -#include -#include +#include +#include #include -namespace Nz +namespace Nz::ShaderAst { - class NAZARA_SHADER_API ShaderAstCloner : public ShaderAstVisitor, public ShaderVarVisitor + class NAZARA_SHADER_API AstCloner : public AstExpressionVisitor, public AstStatementVisitor { public: - ShaderAstCloner() = default; - ShaderAstCloner(const ShaderAstCloner&) = delete; - ShaderAstCloner(ShaderAstCloner&&) = delete; - ~ShaderAstCloner() = default; + AstCloner() = default; + AstCloner(const AstCloner&) = delete; + AstCloner(AstCloner&&) = delete; + ~AstCloner() = default; - ShaderNodes::StatementPtr Clone(const ShaderNodes::StatementPtr& statement); + ExpressionPtr Clone(ExpressionPtr& statement); + StatementPtr Clone(StatementPtr& statement); - ShaderAstCloner& operator=(const ShaderAstCloner&) = delete; - ShaderAstCloner& operator=(ShaderAstCloner&&) = delete; + AstCloner& operator=(const AstCloner&) = delete; + AstCloner& operator=(AstCloner&&) = delete; protected: - ShaderNodes::ExpressionPtr CloneExpression(const ShaderNodes::ExpressionPtr& expr); - ShaderNodes::StatementPtr CloneStatement(const ShaderNodes::StatementPtr& statement); - ShaderNodes::VariablePtr CloneVariable(const ShaderNodes::VariablePtr& statement); + ExpressionPtr CloneExpression(ExpressionPtr& expr); + StatementPtr CloneStatement(StatementPtr& statement); - using ShaderAstVisitor::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::Discard& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::NoOp& node) override; - void Visit(ShaderNodes::ReturnStatement& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; + using AstExpressionVisitor::Visit; + using AstStatementVisitor::Visit; - using ShaderVarVisitor::Visit; - void Visit(ShaderNodes::BuiltinVariable& var) override; - void Visit(ShaderNodes::InputVariable& var) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::OutputVariable& var) override; - void Visit(ShaderNodes::ParameterVariable& var) override; - void Visit(ShaderNodes::UniformVariable& var) override; + void Visit(AccessMemberExpression& node) override; + void Visit(AssignExpression& node) override; + void Visit(BinaryExpression& node) override; + void Visit(CastExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(ConstantExpression& node) override; + void Visit(IdentifierExpression& node) override; + void Visit(IntrinsicExpression& node) override; + void Visit(SwizzleExpression& node) override; + void Visit(BranchStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(DiscardStatement& node) override; + void Visit(ExpressionStatement& node) override; + void Visit(MultiStatement& node) override; + void Visit(NoOpStatement& node) override; + void Visit(ReturnStatement& node) override; - void PushExpression(ShaderNodes::ExpressionPtr expression); - void PushStatement(ShaderNodes::StatementPtr statement); - void PushVariable(ShaderNodes::VariablePtr variable); + void PushExpression(ExpressionPtr expression); + void PushStatement(StatementPtr statement); - ShaderNodes::ExpressionPtr PopExpression(); - ShaderNodes::StatementPtr PopStatement(); - ShaderNodes::VariablePtr PopVariable(); + ExpressionPtr PopExpression(); + StatementPtr PopStatement(); private: - std::vector m_expressionStack; - std::vector m_statementStack; - std::vector m_variableStack; + std::vector m_expressionStack; + std::vector m_statementStack; }; } diff --git a/include/Nazara/Shader/ShaderAstExpressionType.hpp b/include/Nazara/Shader/ShaderAstExpressionType.hpp new file mode 100644 index 000000000..4ad7731e8 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstExpressionType.hpp @@ -0,0 +1,57 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTEXPRESSIONTYPE_HPP +#define NAZARA_SHADERASTEXPRESSIONTYPE_HPP + +#include +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + struct AstCache; + + class NAZARA_SHADER_API ExpressionTypeVisitor : public AstExpressionVisitor + { + public: + ExpressionTypeVisitor() = default; + ExpressionTypeVisitor(const ExpressionTypeVisitor&) = delete; + ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete; + ~ExpressionTypeVisitor() = default; + + ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache); + + ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete; + ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete; + + private: + ShaderExpressionType GetExpressionTypeInternal(Expression& expression); + + void Visit(Expression& expression); + + void Visit(AccessMemberExpression& node) override; + void Visit(AssignExpression& node) override; + void Visit(BinaryExpression& node) override; + void Visit(CastExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(ConstantExpression& node) override; + void Visit(IdentifierExpression& node) override; + void Visit(IntrinsicExpression& node) override; + void Visit(SwizzleExpression& node) override; + + AstCache* m_cache; + std::optional m_lastExpressionType; + }; + + inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr); +} + +#include + +#endif diff --git a/include/Nazara/Shader/ShaderAstExpressionType.inl b/include/Nazara/Shader/ShaderAstExpressionType.inl new file mode 100644 index 000000000..f71146200 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstExpressionType.inl @@ -0,0 +1,17 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache) + { + ExpressionTypeVisitor visitor; + return visitor.GetExpressionType(expression, cache); + } +} + +#include diff --git a/include/Nazara/Shader/ShaderAstExpressionVisitor.hpp b/include/Nazara/Shader/ShaderAstExpressionVisitor.hpp new file mode 100644 index 000000000..83e8b0271 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstExpressionVisitor.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTEXPRESSIONVISITOR_HPP +#define NAZARA_SHADERASTEXPRESSIONVISITOR_HPP + +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API AstExpressionVisitor + { + public: + AstExpressionVisitor() = default; + AstExpressionVisitor(const AstExpressionVisitor&) = delete; + AstExpressionVisitor(AstExpressionVisitor&&) = delete; + virtual ~AstExpressionVisitor(); + +#define NAZARA_SHADERAST_EXPRESSION(NodeType) virtual void Visit(NodeType& node) = 0; +#include + + AstExpressionVisitor& operator=(const AstExpressionVisitor&) = delete; + AstExpressionVisitor& operator=(AstExpressionVisitor&&) = delete; + }; +} + +#endif diff --git a/include/Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp b/include/Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp new file mode 100644 index 000000000..f28e98c78 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTEXPRESSIONVISITOREXCEPT_HPP +#define NAZARA_SHADERASTEXPRESSIONVISITOREXCEPT_HPP + +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API ExpressionVisitorExcept : public AstExpressionVisitor + { + public: + using AstExpressionVisitor::Visit; + +#define NAZARA_SHADERAST_EXPRESSION(Node) void Visit(ShaderAst::Node& node) override; +#include + }; +} + +#endif diff --git a/include/Nazara/Shader/ShaderAstNodes.hpp b/include/Nazara/Shader/ShaderAstNodes.hpp new file mode 100644 index 000000000..7bbce68a9 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstNodes.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Renderer module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#if !defined(NAZARA_SHADERAST_NODE) && !defined(NAZARA_SHADERAST_EXPRESSION) && !defined(NAZARA_SHADERAST_STATEMENT) +#error You must define NAZARA_SHADERAST_NODE or NAZARA_SHADERAST_EXPRESSION or NAZARA_SHADERAST_STATEMENT before including this file +#endif + +#ifndef NAZARA_SHADERAST_NODE +#define NAZARA_SHADERAST_NODE(X) +#endif + +#ifndef NAZARA_SHADERAST_NODE_LAST +#define NAZARA_SHADERAST_NODE_LAST(X) +#endif + +#ifndef NAZARA_SHADERAST_EXPRESSION +#define NAZARA_SHADERAST_EXPRESSION(X) NAZARA_SHADERAST_NODE(X) +#endif + +#ifndef NAZARA_SHADERAST_STATEMENT +#define NAZARA_SHADERAST_STATEMENT(X) NAZARA_SHADERAST_NODE(X) +#endif + +#ifndef NAZARA_SHADERAST_STATEMENT_LAST +#define NAZARA_SHADERAST_STATEMENT_LAST(X) NAZARA_SHADERAST_STATEMENT(X) +#endif + +NAZARA_SHADERAST_EXPRESSION(AccessMemberExpression) +NAZARA_SHADERAST_EXPRESSION(AssignExpression) +NAZARA_SHADERAST_EXPRESSION(BinaryExpression) +NAZARA_SHADERAST_EXPRESSION(CastExpression) +NAZARA_SHADERAST_EXPRESSION(ConditionalExpression) +NAZARA_SHADERAST_EXPRESSION(ConstantExpression) +NAZARA_SHADERAST_EXPRESSION(IdentifierExpression) +NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) +NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) +NAZARA_SHADERAST_STATEMENT(BranchStatement) +NAZARA_SHADERAST_STATEMENT(ConditionalStatement) +NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement) +NAZARA_SHADERAST_STATEMENT(DeclareStructStatement) +NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement) +NAZARA_SHADERAST_STATEMENT(DiscardStatement) +NAZARA_SHADERAST_STATEMENT(ExpressionStatement) +NAZARA_SHADERAST_STATEMENT(MultiStatement) +NAZARA_SHADERAST_STATEMENT(NoOpStatement) +NAZARA_SHADERAST_STATEMENT_LAST(ReturnStatement) + +#undef NAZARA_SHADERAST_EXPRESSION +#undef NAZARA_SHADERAST_NODE +#undef NAZARA_SHADERAST_NODE_LAST +#undef NAZARA_SHADERAST_STATEMENT +#undef NAZARA_SHADERAST_STATEMENT_LAST diff --git a/include/Nazara/Shader/ShaderAstOptimizer.hpp b/include/Nazara/Shader/ShaderAstOptimizer.hpp index d7d56b7cf..83d5fa499 100644 --- a/include/Nazara/Shader/ShaderAstOptimizer.hpp +++ b/include/Nazara/Shader/ShaderAstOptimizer.hpp @@ -12,35 +12,32 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - class ShaderAst; - - class NAZARA_SHADER_API ShaderAstOptimizer : public ShaderAstCloner + class NAZARA_SHADER_API AstOptimizer : public AstCloner { public: - ShaderAstOptimizer() = default; - ShaderAstOptimizer(const ShaderAstOptimizer&) = delete; - ShaderAstOptimizer(ShaderAstOptimizer&&) = delete; - ~ShaderAstOptimizer() = default; + AstOptimizer() = default; + AstOptimizer(const AstOptimizer&) = delete; + AstOptimizer(AstOptimizer&&) = delete; + ~AstOptimizer() = default; - ShaderNodes::StatementPtr Optimise(const ShaderNodes::StatementPtr& statement); - ShaderNodes::StatementPtr Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions); + StatementPtr Optimise(StatementPtr& statement); + StatementPtr Optimise(StatementPtr& statement, UInt64 enabledConditions); - ShaderAstOptimizer& operator=(const ShaderAstOptimizer&) = delete; - ShaderAstOptimizer& operator=(ShaderAstOptimizer&&) = delete; + AstOptimizer& operator=(const AstOptimizer&) = delete; + AstOptimizer& operator=(AstOptimizer&&) = delete; protected: - using ShaderAstCloner::Visit; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; + using AstCloner::Visit; + void Visit(BinaryExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(BranchStatement& node) override; + void Visit(ConditionalStatement& node) override; - template void PropagateConstant(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + template void PropagateConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs); private: - const ShaderAst* m_shaderAst; UInt64 m_enabledConditions; }; } diff --git a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp index 1b7e1eebd..68251ddc0 100644 --- a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp @@ -9,36 +9,37 @@ #include #include -#include +#include +#include -namespace Nz +namespace Nz::ShaderAst { - class NAZARA_SHADER_API ShaderAstRecursiveVisitor : public ShaderAstVisitor + class NAZARA_SHADER_API AstRecursiveVisitor : public AstExpressionVisitor, public AstStatementVisitor { public: - ShaderAstRecursiveVisitor() = default; - ~ShaderAstRecursiveVisitor() = default; + AstRecursiveVisitor() = default; + ~AstRecursiveVisitor() = default; - using ShaderAstVisitor::Visit; + void Visit(AccessMemberExpression& node) override; + void Visit(AssignExpression& node) override; + void Visit(BinaryExpression& node) override; + void Visit(CastExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(ConstantExpression& node) override; + void Visit(IdentifierExpression& node) override; + void Visit(IntrinsicExpression& node) override; + void Visit(SwizzleExpression& node) override; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::Discard& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::NoOp& node) override; - void Visit(ShaderNodes::ReturnStatement& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; + void Visit(BranchStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(DiscardStatement& node) override; + void Visit(ExpressionStatement& node) override; + void Visit(MultiStatement& node) override; + void Visit(NoOpStatement& node) override; + void Visit(ReturnStatement& node) override; }; } diff --git a/include/Nazara/Shader/ShaderAstSerializer.hpp b/include/Nazara/Shader/ShaderAstSerializer.hpp index b3e674fae..fe09ea604 100644 --- a/include/Nazara/Shader/ShaderAstSerializer.hpp +++ b/include/Nazara/Shader/ShaderAstSerializer.hpp @@ -11,40 +11,38 @@ #include #include #include -#include #include -#include -namespace Nz +namespace Nz::ShaderAst { - class NAZARA_SHADER_API ShaderAstSerializerBase + class NAZARA_SHADER_API AstSerializerBase { public: - ShaderAstSerializerBase() = default; - ShaderAstSerializerBase(const ShaderAstSerializerBase&) = delete; - ShaderAstSerializerBase(ShaderAstSerializerBase&&) = delete; - ~ShaderAstSerializerBase() = default; + AstSerializerBase() = default; + AstSerializerBase(const AstSerializerBase&) = delete; + AstSerializerBase(AstSerializerBase&&) = delete; + ~AstSerializerBase() = default; - void Serialize(ShaderNodes::AccessMember& node); - void Serialize(ShaderNodes::AssignOp& node); - void Serialize(ShaderNodes::BinaryOp& node); - void Serialize(ShaderNodes::BuiltinVariable& var); - void Serialize(ShaderNodes::Branch& node); - void Serialize(ShaderNodes::Cast& node); - void Serialize(ShaderNodes::ConditionalExpression& node); - void Serialize(ShaderNodes::ConditionalStatement& node); - void Serialize(ShaderNodes::Constant& node); - void Serialize(ShaderNodes::DeclareVariable& node); - void Serialize(ShaderNodes::Discard& node); - void Serialize(ShaderNodes::ExpressionStatement& node); - void Serialize(ShaderNodes::Identifier& node); - void Serialize(ShaderNodes::IntrinsicCall& node); - void Serialize(ShaderNodes::NamedVariable& var); - void Serialize(ShaderNodes::NoOp& node); - void Serialize(ShaderNodes::ReturnStatement& node); - void Serialize(ShaderNodes::Sample2D& node); - void Serialize(ShaderNodes::StatementBlock& node); - void Serialize(ShaderNodes::SwizzleOp& node); + void Serialize(AccessMemberExpression& node); + void Serialize(AssignExpression& node); + void Serialize(BinaryExpression& node); + void Serialize(CastExpression& node); + void Serialize(ConditionalExpression& node); + void Serialize(ConstantExpression& node); + void Serialize(IdentifierExpression& node); + void Serialize(IntrinsicExpression& node); + void Serialize(SwizzleExpression& node); + + void Serialize(BranchStatement& node); + void Serialize(ConditionalStatement& node); + void Serialize(DeclareFunctionStatement& node); + void Serialize(DeclareStructStatement& node); + void Serialize(DeclareVariableStatement& node); + void Serialize(DiscardStatement& node); + void Serialize(ExpressionStatement& node); + void Serialize(MultiStatement& node); + void Serialize(NoOpStatement& node); + void Serialize(ReturnStatement& node); protected: template void Container(T& container); @@ -54,8 +52,8 @@ namespace Nz virtual bool IsWriting() const = 0; - virtual void Node(ShaderNodes::NodePtr& node) = 0; - template void Node(std::shared_ptr& node); + virtual void Node(ExpressionPtr& node) = 0; + virtual void Node(StatementPtr& node) = 0; virtual void Type(ShaderExpressionType& type) = 0; @@ -74,23 +72,20 @@ namespace Nz virtual void Value(UInt32& val) = 0; virtual void Value(UInt64& val) = 0; inline void SizeT(std::size_t& val); - - virtual void Variable(ShaderNodes::VariablePtr& var) = 0; - template void Variable(std::shared_ptr& var); }; - class NAZARA_SHADER_API ShaderAstSerializer final : public ShaderAstSerializerBase + class NAZARA_SHADER_API ShaderAstSerializer final : public AstSerializerBase { public: inline ShaderAstSerializer(ByteStream& stream); ~ShaderAstSerializer() = default; - void Serialize(const ShaderAst& shader); + void Serialize(StatementPtr& shader); private: bool IsWriting() const override; - void Node(const ShaderNodes::NodePtr& node); - void Node(ShaderNodes::NodePtr& node) override; + void Node(ExpressionPtr& node) override; + void Node(StatementPtr& node) override; void Type(ShaderExpressionType& type) override; void Value(bool& val) override; void Value(float& val) override; @@ -106,22 +101,22 @@ namespace Nz void Value(UInt16& val) override; void Value(UInt32& val) override; void Value(UInt64& val) override; - void Variable(ShaderNodes::VariablePtr& var) override; ByteStream& m_stream; }; - class NAZARA_SHADER_API ShaderAstUnserializer final : public ShaderAstSerializerBase + class NAZARA_SHADER_API ShaderAstUnserializer final : public AstSerializerBase { public: ShaderAstUnserializer(ByteStream& stream); ~ShaderAstUnserializer() = default; - ShaderAst Unserialize(); + StatementPtr Unserialize(); private: bool IsWriting() const override; - void Node(ShaderNodes::NodePtr& node) override; + void Node(ExpressionPtr& node) override; + void Node(StatementPtr& node) override; void Type(ShaderExpressionType& type) override; void Value(bool& val) override; void Value(float& val) override; @@ -137,14 +132,13 @@ namespace Nz void Value(UInt16& val) override; void Value(UInt32& val) override; void Value(UInt64& val) override; - void Variable(ShaderNodes::VariablePtr& var) override; ByteStream& m_stream; }; - NAZARA_SHADER_API ByteArray SerializeShader(const ShaderAst& shader); - inline ShaderAst UnserializeShader(const void* data, std::size_t size); - NAZARA_SHADER_API ShaderAst UnserializeShader(ByteStream& stream); + NAZARA_SHADER_API ByteArray SerializeShader(StatementPtr& shader); + inline StatementPtr UnserializeShader(const void* data, std::size_t size); + NAZARA_SHADER_API StatementPtr UnserializeShader(ByteStream& stream); } #include diff --git a/include/Nazara/Shader/ShaderAstSerializer.inl b/include/Nazara/Shader/ShaderAstSerializer.inl index 6b61c1197..c1b2d41c6 100644 --- a/include/Nazara/Shader/ShaderAstSerializer.inl +++ b/include/Nazara/Shader/ShaderAstSerializer.inl @@ -5,10 +5,10 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { template - void ShaderAstSerializerBase::Container(T& container) + void AstSerializerBase::Container(T& container) { bool isWriting = IsWriting(); @@ -23,7 +23,7 @@ namespace Nz template - void ShaderAstSerializerBase::Enum(T& enumVal) + void AstSerializerBase::Enum(T& enumVal) { bool isWriting = IsWriting(); @@ -37,7 +37,7 @@ namespace Nz } template - void ShaderAstSerializerBase::OptEnum(std::optional& optVal) + void AstSerializerBase::OptEnum(std::optional& optVal) { bool isWriting = IsWriting(); @@ -55,7 +55,7 @@ namespace Nz } template - void ShaderAstSerializerBase::OptVal(std::optional& optVal) + void AstSerializerBase::OptVal(std::optional& optVal) { bool isWriting = IsWriting(); @@ -77,21 +77,7 @@ namespace Nz } } - template - void ShaderAstSerializerBase::Node(std::shared_ptr& node) - { - bool isWriting = IsWriting(); - - ShaderNodes::NodePtr value; - if (isWriting) - value = node; - - Node(value); - if (!isWriting) - node = std::static_pointer_cast(value); - } - - inline void ShaderAstSerializerBase::SizeT(std::size_t& val) + inline void AstSerializerBase::SizeT(std::size_t& val) { bool isWriting = IsWriting(); @@ -105,20 +91,6 @@ namespace Nz val = static_cast(fixedVal); } - template - void ShaderAstSerializerBase::Variable(std::shared_ptr& var) - { - bool isWriting = IsWriting(); - - ShaderNodes::VariablePtr value; - if (isWriting) - value = var; - - Variable(value); - if (!isWriting) - var = std::static_pointer_cast(value); - } - inline ShaderAstSerializer::ShaderAstSerializer(ByteStream& stream) : m_stream(stream) { @@ -129,7 +101,7 @@ namespace Nz { } - inline ShaderAst UnserializeShader(const void* data, std::size_t size) + inline StatementPtr UnserializeShader(const void* data, std::size_t size) { ByteStream byteStream(data, size); return UnserializeShader(byteStream); diff --git a/include/Nazara/Shader/ShaderAstStatementVisitor.hpp b/include/Nazara/Shader/ShaderAstStatementVisitor.hpp new file mode 100644 index 000000000..2da7e28a1 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstStatementVisitor.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTSTATEMENTVISITOR_HPP +#define NAZARA_SHADERASTSTATEMENTVISITOR_HPP + +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API AstStatementVisitor + { + public: + AstStatementVisitor() = default; + AstStatementVisitor(const AstStatementVisitor&) = delete; + AstStatementVisitor(AstStatementVisitor&&) = delete; + virtual ~AstStatementVisitor(); + +#define NAZARA_SHADERAST_STATEMENT(NodeType) virtual void Visit(ShaderAst::NodeType& node) = 0; +#include + + AstStatementVisitor& operator=(const AstStatementVisitor&) = delete; + AstStatementVisitor& operator=(AstStatementVisitor&&) = delete; + }; +} + +#endif diff --git a/include/Nazara/Shader/ShaderAstStatementVisitorExcept.hpp b/include/Nazara/Shader/ShaderAstStatementVisitorExcept.hpp new file mode 100644 index 000000000..d5a85416f --- /dev/null +++ b/include/Nazara/Shader/ShaderAstStatementVisitorExcept.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTSTATEMENTVISITOREXCEPT_HPP +#define NAZARA_SHADERASTSTATEMENTVISITOREXCEPT_HPP + +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API StatementVisitorExcept : public AstStatementVisitor + { + public: + using AstStatementVisitor::Visit; + +#define NAZARA_SHADERAST_STATEMENT(Node) void Visit(ShaderAst::Node& node) override; +#include + }; +} + +#endif diff --git a/include/Nazara/Shader/ShaderExpressionType.hpp b/include/Nazara/Shader/ShaderAstTypes.hpp similarity index 56% rename from include/Nazara/Shader/ShaderExpressionType.hpp rename to include/Nazara/Shader/ShaderAstTypes.hpp index 6d5385121..e1f50d4e7 100644 --- a/include/Nazara/Shader/ShaderExpressionType.hpp +++ b/include/Nazara/Shader/ShaderAstTypes.hpp @@ -4,17 +4,30 @@ #pragma once -#ifndef NAZARA_SHADER_EXPRESSIONTYPE_HPP -#define NAZARA_SHADER_EXPRESSIONTYPE_HPP +#ifndef NAZARA_SHADER_ASTTYPES_HPP +#define NAZARA_SHADER_ASTTYPES_HPP #include #include #include #include +#include -namespace Nz +namespace Nz::ShaderAst { - using ShaderExpressionType = std::variant; + using ShaderExpressionType = std::variant; + + struct StructDescription + { + struct StructMember + { + std::string name; + ShaderExpressionType type; + }; + + std::string name; + std::vector members; + }; inline bool IsBasicType(const ShaderExpressionType& type); inline bool IsMatrixType(const ShaderExpressionType& type); @@ -22,6 +35,6 @@ namespace Nz inline bool IsStructType(const ShaderExpressionType& type); } -#include +#include -#endif // NAZARA_SHADER_EXPRESSIONTYPE_HPP +#endif // NAZARA_SHADER_ASTTYPES_HPP diff --git a/include/Nazara/Shader/ShaderExpressionType.inl b/include/Nazara/Shader/ShaderAstTypes.inl similarity index 89% rename from include/Nazara/Shader/ShaderExpressionType.inl rename to include/Nazara/Shader/ShaderAstTypes.inl index 8f72fb376..6eed4a945 100644 --- a/include/Nazara/Shader/ShaderExpressionType.inl +++ b/include/Nazara/Shader/ShaderAstTypes.inl @@ -2,18 +2,18 @@ // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp -#include +#include #include #include -namespace Nz +namespace Nz::ShaderAst { inline bool IsBasicType(const ShaderExpressionType& type) { return std::visit([&](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) return true; else if constexpr (std::is_same_v) return false; @@ -25,8 +25,6 @@ namespace Nz inline bool IsMatrixType(const ShaderExpressionType& type) { - using namespace ShaderNodes; - if (!IsBasicType(type)) return false; @@ -58,8 +56,6 @@ namespace Nz inline bool IsSamplerType(const ShaderExpressionType& type) { - using namespace ShaderNodes; - if (!IsBasicType(type)) return false; @@ -94,7 +90,7 @@ namespace Nz return std::visit([&](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) return false; else if constexpr (std::is_same_v) return true; diff --git a/include/Nazara/Shader/ShaderAstUtils.hpp b/include/Nazara/Shader/ShaderAstUtils.hpp index e78ee657f..3f577ed3e 100644 --- a/include/Nazara/Shader/ShaderAstUtils.hpp +++ b/include/Nazara/Shader/ShaderAstUtils.hpp @@ -10,14 +10,12 @@ #include #include #include -#include +#include #include -namespace Nz +namespace Nz::ShaderAst { - class ShaderAst; - - class NAZARA_SHADER_API ShaderAstValueCategory final : public ShaderAstVisitorExcept + class NAZARA_SHADER_API ShaderAstValueCategory final : public AstExpressionVisitor { public: ShaderAstValueCategory() = default; @@ -25,28 +23,28 @@ namespace Nz ShaderAstValueCategory(ShaderAstValueCategory&&) = delete; ~ShaderAstValueCategory() = default; - ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression); + ExpressionCategory GetExpressionCategory(Expression& expression); ShaderAstValueCategory& operator=(const ShaderAstValueCategory&) = delete; ShaderAstValueCategory& operator=(ShaderAstValueCategory&&) = delete; private: - using ShaderAstVisitorExcept::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; + using AstExpressionVisitor::Visit; - ShaderNodes::ExpressionCategory m_expressionCategory; + void Visit(AccessMemberExpression& node) override; + void Visit(AssignExpression& node) override; + void Visit(BinaryExpression& node) override; + void Visit(CastExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(ConstantExpression& node) override; + void Visit(IdentifierExpression& node) override; + void Visit(IntrinsicExpression& node) override; + void Visit(SwizzleExpression& node) override; + + ExpressionCategory m_expressionCategory; }; - inline ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression); + inline ExpressionCategory GetExpressionCategory(Expression& expression); } #include diff --git a/include/Nazara/Shader/ShaderAstUtils.inl b/include/Nazara/Shader/ShaderAstUtils.inl index 852b2e685..dec5ed3a9 100644 --- a/include/Nazara/Shader/ShaderAstUtils.inl +++ b/include/Nazara/Shader/ShaderAstUtils.inl @@ -5,9 +5,9 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression) + ExpressionCategory GetExpressionCategory(Expression& expression) { ShaderAstValueCategory visitor; return visitor.GetExpressionCategory(expression); diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index 90ec82c71..00d708d96 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -8,66 +8,62 @@ #define NAZARA_SHADERVALIDATOR_HPP #include -#include -#include #include -#include +#include #include -#include +#include -namespace Nz +namespace Nz::ShaderAst { - class NAZARA_SHADER_API ShaderAstValidator : public ShaderAstRecursiveVisitor, public ShaderVarVisitor + class NAZARA_SHADER_API AstValidator : public AstRecursiveVisitor { public: - inline ShaderAstValidator(const ShaderAst& shader); - ShaderAstValidator(const ShaderAstValidator&) = delete; - ShaderAstValidator(ShaderAstValidator&&) = delete; - ~ShaderAstValidator() = default; + inline AstValidator(); + AstValidator(const AstValidator&) = delete; + AstValidator(AstValidator&&) = delete; + ~AstValidator() = default; - bool Validate(std::string* error = nullptr); + bool Validate(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr); private: - const ShaderNodes::ExpressionPtr& MandatoryExpr(const ShaderNodes::ExpressionPtr& node); - const ShaderNodes::NodePtr& MandatoryNode(const ShaderNodes::NodePtr& node); - void TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right); + Expression& MandatoryExpr(ExpressionPtr& node); + Statement& MandatoryStatement(StatementPtr& node); + void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right); void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right); - const ShaderAst::StructMember& CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers); + ShaderExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); - using ShaderAstRecursiveVisitor::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::ReturnStatement& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; + AstCache::Scope& EnterScope(); + void ExitScope(); - using ShaderVarVisitor::Visit; - void Visit(ShaderNodes::BuiltinVariable& var) override; - void Visit(ShaderNodes::InputVariable& var) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::OutputVariable& var) override; - void Visit(ShaderNodes::ParameterVariable& var) override; - void Visit(ShaderNodes::UniformVariable& var) override; + void RegisterExpressionType(Expression& node, ShaderExpressionType expressionType); + void RegisterScope(Node& node); + + void Visit(AccessMemberExpression& node) override; + void Visit(AssignExpression& node) override; + void Visit(BinaryExpression& node) override; + void Visit(CastExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(ConstantExpression& node) override; + void Visit(IdentifierExpression& node) override; + void Visit(IntrinsicExpression& node) override; + void Visit(SwizzleExpression& node) override; + + void Visit(BranchStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(ExpressionStatement& node) override; + void Visit(MultiStatement& node) override; + void Visit(ReturnStatement& node) override; struct Context; - const ShaderAst& m_shader; Context* m_context; }; - NAZARA_SHADER_API bool ValidateShader(const ShaderAst& shader, std::string* error = nullptr); + NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr); } #include diff --git a/include/Nazara/Shader/ShaderAstValidator.inl b/include/Nazara/Shader/ShaderAstValidator.inl index eed116766..2020badd4 100644 --- a/include/Nazara/Shader/ShaderAstValidator.inl +++ b/include/Nazara/Shader/ShaderAstValidator.inl @@ -5,10 +5,10 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - ShaderAstValidator::ShaderAstValidator(const ShaderAst& shader) : - m_shader(shader) + AstValidator::AstValidator() : + m_context(nullptr) { } } diff --git a/include/Nazara/Shader/ShaderAstVisitor.hpp b/include/Nazara/Shader/ShaderAstVisitor.hpp deleted file mode 100644 index 183a58f92..000000000 --- a/include/Nazara/Shader/ShaderAstVisitor.hpp +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERASTVISITOR_HPP -#define NAZARA_SHADERASTVISITOR_HPP - -#include -#include -#include - -namespace Nz -{ - class NAZARA_SHADER_API ShaderAstVisitor - { - public: - ShaderAstVisitor() = default; - ShaderAstVisitor(const ShaderAstVisitor&) = delete; - ShaderAstVisitor(ShaderAstVisitor&&) = delete; - virtual ~ShaderAstVisitor(); - - void Visit(const ShaderNodes::NodePtr& node); - virtual void Visit(ShaderNodes::AccessMember& node) = 0; - virtual void Visit(ShaderNodes::AssignOp& node) = 0; - virtual void Visit(ShaderNodes::BinaryOp& node) = 0; - virtual void Visit(ShaderNodes::Branch& node) = 0; - virtual void Visit(ShaderNodes::Cast& node) = 0; - virtual void Visit(ShaderNodes::ConditionalExpression& node) = 0; - virtual void Visit(ShaderNodes::ConditionalStatement& node) = 0; - virtual void Visit(ShaderNodes::Constant& node) = 0; - virtual void Visit(ShaderNodes::DeclareVariable& node) = 0; - virtual void Visit(ShaderNodes::Discard& node) = 0; - virtual void Visit(ShaderNodes::ExpressionStatement& node) = 0; - virtual void Visit(ShaderNodes::Identifier& node) = 0; - virtual void Visit(ShaderNodes::IntrinsicCall& node) = 0; - virtual void Visit(ShaderNodes::NoOp& node) = 0; - virtual void Visit(ShaderNodes::ReturnStatement& node) = 0; - virtual void Visit(ShaderNodes::Sample2D& node) = 0; - virtual void Visit(ShaderNodes::StatementBlock& node) = 0; - virtual void Visit(ShaderNodes::SwizzleOp& node) = 0; - - ShaderAstVisitor& operator=(const ShaderAstVisitor&) = delete; - ShaderAstVisitor& operator=(ShaderAstVisitor&&) = delete; - }; -} - -#endif diff --git a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp deleted file mode 100644 index 5076284fe..000000000 --- a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERASTVISITOREXCEPT_HPP -#define NAZARA_SHADERASTVISITOREXCEPT_HPP - -#include -#include -#include - -namespace Nz -{ - class NAZARA_SHADER_API ShaderAstVisitorExcept : public ShaderAstVisitor - { - public: - using ShaderAstVisitor::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::Discard& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::NoOp& node) override; - void Visit(ShaderNodes::ReturnStatement& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; - }; -} - -#endif diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 61dab693e..6a18b25ed 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -13,70 +13,60 @@ namespace Nz::ShaderBuilder { - template - struct AssignOpBuilder + namespace Impl { - constexpr AssignOpBuilder() = default; + struct Binary + { + inline std::unique_ptr operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const; + }; - std::shared_ptr operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const; - }; + struct Branch + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath = nullptr) const; + inline std::unique_ptr operator()(std::vector condStatements, ShaderAst::StatementPtr elseStatement = nullptr) const; + }; - template - struct BinOpBuilder - { - constexpr BinOpBuilder() = default; + struct Constant + { + inline std::unique_ptr operator()(ShaderConstantValue value) const; + }; - std::shared_ptr operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const; - }; + struct DeclareFunction + { + inline std::unique_ptr operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const; + }; - struct BuiltinBuilder - { - constexpr BuiltinBuilder() = default; + struct DeclareVariable + { + inline std::unique_ptr operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; + }; - inline std::shared_ptr operator()(ShaderNodes::BuiltinEntry builtin) const; - }; + struct Identifier + { + inline std::unique_ptr operator()(std::string name) const; + }; - template - struct GenBuilder - { - constexpr GenBuilder() = default; + struct Return + { + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expr = nullptr) const; + }; - template std::shared_ptr operator()(Args&&... args) const; - }; + template + struct NoParam + { + std::unique_ptr operator()() const; + }; + } - constexpr GenBuilder AccessMember; - constexpr BinOpBuilder Add; - constexpr AssignOpBuilder Assign; - constexpr BuiltinBuilder Builtin; - constexpr GenBuilder Block; - constexpr GenBuilder Branch; - constexpr GenBuilder ConditionalExpression; - constexpr GenBuilder ConditionalStatement; - constexpr GenBuilder Constant; - constexpr GenBuilder DeclareVariable; - constexpr GenBuilder Discard; - constexpr BinOpBuilder Division; - constexpr BinOpBuilder Equal; - constexpr BinOpBuilder GreaterThan; - constexpr BinOpBuilder GreaterThanOrEqual; - constexpr BinOpBuilder LessThan; - constexpr BinOpBuilder LessThanOrEqual; - constexpr BinOpBuilder NotEqual; - constexpr GenBuilder ExprStatement; - constexpr GenBuilder Identifier; - constexpr GenBuilder IntrinsicCall; - constexpr GenBuilder Input; - constexpr GenBuilder Local; - constexpr BinOpBuilder Multiply; - constexpr GenBuilder Output; - constexpr GenBuilder Parameter; - constexpr GenBuilder Sample2D; - constexpr GenBuilder StatementBlock; - constexpr GenBuilder Swizzle; - constexpr BinOpBuilder Subtract; - constexpr GenBuilder Uniform; - - template std::shared_ptr Cast(Args&&... args); + constexpr Impl::Binary Binary; + constexpr Impl::Branch Branch; + constexpr Impl::Constant Constant; + constexpr Impl::DeclareFunction DeclareFunction; + constexpr Impl::DeclareVariable DeclareVariable; + constexpr Impl::NoParam Discard; + constexpr Impl::Identifier Identifier; + constexpr Impl::NoParam NoOp; + constexpr Impl::Return Return; } #include diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index c3221c84f..ef7e89849 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -7,45 +7,87 @@ namespace Nz::ShaderBuilder { + inline std::unique_ptr Impl::Binary::operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const + { + auto constantNode = std::make_unique(); + constantNode->op = op; + constantNode->left = std::move(left); + constantNode->right = std::move(right); + + return constantNode; + } + + inline std::unique_ptr Impl::Branch::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const + { + auto branchNode = std::make_unique(); + + auto& condStatement = branchNode->condStatements.emplace_back(); + condStatement.condition = std::move(condition); + condStatement.statement = std::move(truePath); + + branchNode->elseStatement = std::move(falsePath); + + return branchNode; + } + + inline std::unique_ptr Impl::Branch::operator()(std::vector condStatements, ShaderAst::StatementPtr elseStatement) const + { + auto branchNode = std::make_unique(); + branchNode->condStatements = std::move(condStatements); + branchNode->elseStatement = std::move(elseStatement); + + return branchNode; + } + + inline std::unique_ptr Impl::Constant::operator()(ShaderConstantValue value) const + { + auto constantNode = std::make_unique(); + constantNode->value = std::move(value); + + return constantNode; + } + + inline std::unique_ptr Impl::DeclareFunction::operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType) const + { + auto declareFunctionNode = std::make_unique(); + declareFunctionNode->name = std::move(name); + declareFunctionNode->parameters = std::move(parameters); + declareFunctionNode->returnType = std::move(returnType); + declareFunctionNode->statements = std::move(statements); + + return declareFunctionNode; + } + + inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue) const + { + auto declareVariableNode = std::make_unique(); + declareVariableNode->varName = std::move(name); + declareVariableNode->varType = std::move(type); + declareVariableNode->initialExpression = std::move(initialValue); + + return declareVariableNode; + } + + inline std::unique_ptr Impl::Identifier::operator()(std::string name) const + { + auto identifierNode = std::make_unique(); + identifierNode->identifier = std::move(name); + + return identifierNode; + } + + inline std::unique_ptr Impl::Return::operator()(ShaderAst::ExpressionPtr expr) const + { + auto returnNode = std::make_unique(); + returnNode->returnExpr = std::move(expr); + + return returnNode; + } + template - template - std::shared_ptr GenBuilder::operator()(Args&&... args) const + std::unique_ptr Impl::NoParam::operator()() const { - return T::Build(std::forward(args)...); - } - - template - std::shared_ptr AssignOpBuilder::operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const - { - return ShaderNodes::AssignOp::Build(op, left, right); - } - - template - std::shared_ptr BinOpBuilder::operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const - { - return ShaderNodes::BinaryOp::Build(op, left, right); - } - - inline std::shared_ptr BuiltinBuilder::operator()(ShaderNodes::BuiltinEntry builtin) const - { - ShaderNodes::BasicType exprType = ShaderNodes::BasicType::Void; - - switch (builtin) - { - case ShaderNodes::BuiltinEntry::VertexPosition: - exprType = ShaderNodes::BasicType::Float4; - break; - } - - NazaraAssert(exprType != ShaderNodes::BasicType::Void, "Unhandled builtin"); - - return ShaderNodes::BuiltinVariable::Build(builtin, exprType); - } - - template - std::shared_ptr Cast(Args&&... args) - { - return ShaderNodes::Cast::Build(Type, std::forward(args)...); + return std::make_unique(); } } diff --git a/include/Nazara/Shader/ShaderEnums.hpp b/include/Nazara/Shader/ShaderEnums.hpp index 87f2f4cd7..f5182c243 100644 --- a/include/Nazara/Shader/ShaderEnums.hpp +++ b/include/Nazara/Shader/ShaderEnums.hpp @@ -9,7 +9,7 @@ #include -namespace Nz::ShaderNodes +namespace Nz::ShaderAst { enum class AssignType { @@ -77,35 +77,9 @@ namespace Nz::ShaderNodes { None = -1, - AccessMember, - AssignOp, - BinaryOp, - Branch, - Cast, - Constant, - ConditionalExpression, - ConditionalStatement, - DeclareVariable, - Discard, - ExpressionStatement, - Identifier, - IntrinsicCall, - NoOp, - ReturnStatement, - Sample2D, - SwizzleOp, - StatementBlock, - - Max = StatementBlock - }; - - enum class SsaInstruction - { - OpAdd, - OpDiv, - OpMul, - OpSub, - OpSample +#define NAZARA_SHADERAST_NODE(Node) Node, +#define NAZARA_SHADERAST_STATEMENT_LAST(Node) Node, Max = Node +#include }; enum class SwizzleComponent @@ -127,6 +101,11 @@ namespace Nz::ShaderNodes ParameterVariable, UniformVariable }; + + inline std::size_t GetComponentCount(BasicType type); + inline BasicType GetComponentType(BasicType type); } +#include + #endif // NAZARA_SHADER_ENUMS_HPP diff --git a/include/Nazara/Shader/ShaderEnums.inl b/include/Nazara/Shader/ShaderEnums.inl new file mode 100644 index 000000000..fbd01ad49 --- /dev/null +++ b/include/Nazara/Shader/ShaderEnums.inl @@ -0,0 +1,57 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline std::size_t GetComponentCount(BasicType type) + { + switch (type) + { + case BasicType::Float2: + case BasicType::Int2: + return 2; + + case BasicType::Float3: + case BasicType::Int3: + return 3; + + case BasicType::Float4: + case BasicType::Int4: + return 4; + + case BasicType::Mat4x4: + return 4; + + default: + return 1; + } + } + + inline BasicType GetComponentType(BasicType type) + { + switch (type) + { + case BasicType::Float2: + case BasicType::Float3: + case BasicType::Float4: + return BasicType::Float1; + + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: + return BasicType::Int1; + + case BasicType::Mat4x4: + return BasicType::Float4; + + default: + return type; + } + } +} + +#include diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 6b3d296a3..0588ecc44 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include namespace Nz::ShaderLang { @@ -44,7 +44,7 @@ namespace Nz::ShaderLang inline Parser(); ~Parser() = default; - ShaderAst Parse(const std::vector& tokens); + ShaderAst::StatementPtr Parse(const std::vector& tokens); private: // Flow control @@ -54,29 +54,30 @@ namespace Nz::ShaderLang const Token& PeekNext(); // Statements - ShaderNodes::StatementPtr ParseFunctionBody(); - void ParseFunctionDeclaration(); - ShaderAst::FunctionParameter ParseFunctionParameter(); - ShaderNodes::StatementPtr ParseReturnStatement(); - ShaderNodes::StatementPtr ParseStatement(); - ShaderNodes::StatementPtr ParseStatementList(); + std::vector ParseFunctionBody(); + ShaderAst::StatementPtr ParseFunctionDeclaration(); + ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter(); + ShaderAst::StatementPtr ParseReturnStatement(); + ShaderAst::StatementPtr ParseStatement(); + std::vector ParseStatementList(); + ShaderAst::StatementPtr ParseVariableDeclaration(); // Expressions - ShaderNodes::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs); - ShaderNodes::ExpressionPtr ParseExpression(); - ShaderNodes::ExpressionPtr ParseIdentifier(); - ShaderNodes::ExpressionPtr ParseIntegerExpression(); - ShaderNodes::ExpressionPtr ParseParenthesisExpression(); - ShaderNodes::ExpressionPtr ParsePrimaryExpression(); + ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs); + ShaderAst::ExpressionPtr ParseExpression(); + ShaderAst::ExpressionPtr ParseIdentifier(); + ShaderAst::ExpressionPtr ParseIntegerExpression(); + ShaderAst::ExpressionPtr ParseParenthesisExpression(); + ShaderAst::ExpressionPtr ParsePrimaryExpression(); std::string ParseIdentifierAsName(); - ShaderExpressionType ParseIdentifierAsType(); + ShaderAst::ShaderExpressionType ParseIdentifierAsType(); static int GetTokenPrecedence(TokenType token); struct Context { - ShaderAst result; + std::unique_ptr root; std::size_t tokenCount; std::size_t tokenIndex = 0; const Token* tokens; diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index caaa2964d..d8eb56139 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -6,10 +6,11 @@ #error You must define NAZARA_SHADERLANG_TOKEN before including this file #endif -#ifndef NAZARA_SHADERLANG_TOKENT_LAST +#ifndef NAZARA_SHADERLANG_TOKEN_LAST #define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X) #endif +NAZARA_SHADERLANG_TOKEN(Assign) NAZARA_SHADERLANG_TOKEN(BoolFalse) NAZARA_SHADERLANG_TOKEN(BoolTrue) NAZARA_SHADERLANG_TOKEN(ClosingParenthesis) @@ -24,6 +25,7 @@ NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) NAZARA_SHADERLANG_TOKEN(FunctionReturn) NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(Identifier) +NAZARA_SHADERLANG_TOKEN(Let) NAZARA_SHADERLANG_TOKEN(Multiply) NAZARA_SHADERLANG_TOKEN(Minus) NAZARA_SHADERLANG_TOKEN(Plus) diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index 0d1d57183..7aeef561c 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -14,308 +14,245 @@ #include #include #include -#include -#include +#include #include #include #include #include -namespace Nz +namespace Nz::ShaderAst { - class ShaderAstVisitor; + class AstExpressionVisitor; + class AstStatementVisitor; - namespace ShaderNodes + struct NAZARA_SHADER_API Node { - class Node; + Node() = default; + Node(const Node&) = delete; + Node(Node&&) noexcept = default; + virtual ~Node(); - using NodePtr = std::shared_ptr; + virtual NodeType GetType() const = 0; - class NAZARA_SHADER_API Node + Node& operator=(const Node&) = delete; + Node& operator=(Node&&) noexcept = default; + }; + + // Expressions + + struct Expression; + + using ExpressionPtr = std::unique_ptr; + + struct NAZARA_SHADER_API Expression : Node + { + Expression() = default; + Expression(const Expression&) = delete; + Expression(Expression&&) noexcept = default; + ~Expression() = default; + + virtual void Visit(AstExpressionVisitor& visitor) = 0; + + Expression& operator=(const Expression&) = delete; + Expression& operator=(Expression&&) noexcept = default; + }; + + struct NAZARA_SHADER_API AccessMemberExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + ExpressionPtr structExpr; + std::vector memberIdentifiers; + }; + + struct NAZARA_SHADER_API AssignExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + AssignType op; + ExpressionPtr left; + ExpressionPtr right; + }; + + struct NAZARA_SHADER_API BinaryExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + BinaryType op; + ExpressionPtr left; + ExpressionPtr right; + }; + + struct NAZARA_SHADER_API CastExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + BasicType targetType; + std::array expressions; + }; + + struct NAZARA_SHADER_API ConditionalExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::string conditionName; + ExpressionPtr falsePath; + ExpressionPtr truePath; + }; + + struct NAZARA_SHADER_API ConstantExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + ShaderConstantValue value; + }; + + struct NAZARA_SHADER_API IdentifierExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::string identifier; + }; + + struct NAZARA_SHADER_API IntrinsicExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + IntrinsicType intrinsic; + std::vector parameters; + }; + + struct NAZARA_SHADER_API SwizzleExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::array components; + std::size_t componentCount; + ExpressionPtr expression; + }; + + // Statements + + struct Statement; + + using StatementPtr = std::unique_ptr; + + struct NAZARA_SHADER_API Statement : Node + { + Statement() = default; + Statement(const Statement&) = delete; + Statement(Statement&&) noexcept = default; + ~Statement() = default; + + virtual void Visit(AstStatementVisitor& visitor) = 0; + + Statement& operator=(const Statement&) = delete; + Statement& operator=(Statement&&) noexcept = default; + }; + + struct NAZARA_SHADER_API BranchStatement : public Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + struct ConditionalStatement { - public: - virtual ~Node(); - - inline NodeType GetType() const; - inline bool IsStatement() const; - - virtual void Visit(ShaderAstVisitor& visitor) = 0; - - static inline unsigned int GetComponentCount(BasicType type); - static inline BasicType GetComponentType(BasicType type); - - protected: - inline Node(NodeType type, bool isStatement); - - private: - NodeType m_type; - bool m_isStatement; - }; - - class Expression; - - using ExpressionPtr = std::shared_ptr; - - class NAZARA_SHADER_API Expression : public Node, public std::enable_shared_from_this - { - public: - inline Expression(NodeType type); - - virtual ShaderExpressionType GetExpressionType() const = 0; + ExpressionPtr condition; + StatementPtr statement; }; - class Statement; + std::vector condStatements; + StatementPtr elseStatement; + }; - using StatementPtr = std::shared_ptr; + struct NAZARA_SHADER_API ConditionalStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; - class NAZARA_SHADER_API Statement : public Node, public std::enable_shared_from_this + std::string conditionName; + StatementPtr statement; + }; + + struct NAZARA_SHADER_API DeclareFunctionStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + struct Parameter { - public: - inline Statement(NodeType type); + std::string name; + ShaderExpressionType type; }; - struct NAZARA_SHADER_API ExpressionStatement : public Statement - { - inline ExpressionStatement(); - - void Visit(ShaderAstVisitor& visitor) override; - - ExpressionPtr expression; - - static inline std::shared_ptr Build(ExpressionPtr expr); - }; - - ////////////////////////////////////////////////////////////////////////// - - struct NAZARA_SHADER_API ConditionalStatement : public Statement - { - inline ConditionalStatement(); - - void Visit(ShaderAstVisitor& visitor) override; - - std::string conditionName; - StatementPtr statement; - - static inline std::shared_ptr Build(std::string condition, StatementPtr statementPtr); - }; - - struct NAZARA_SHADER_API StatementBlock : public Statement - { - inline StatementBlock(); - - void Visit(ShaderAstVisitor& visitor) override; - - std::vector statements; - - static inline std::shared_ptr Build(std::vector statements); - template static std::shared_ptr Build(Args&&... args); - }; - - struct NAZARA_SHADER_API DeclareVariable : public Statement - { - inline DeclareVariable(); - - void Visit(ShaderAstVisitor& visitor) override; - - ExpressionPtr expression; - VariablePtr variable; - - static inline std::shared_ptr Build(VariablePtr variable, ExpressionPtr expression = nullptr); - }; - - struct NAZARA_SHADER_API Discard : public Statement - { - inline Discard(); - - void Visit(ShaderAstVisitor& visitor) override; - - static inline std::shared_ptr Build(); - }; - - struct NAZARA_SHADER_API Identifier : public Expression - { - inline Identifier(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - VariablePtr var; - - static inline std::shared_ptr Build(VariablePtr variable); - }; - - struct NAZARA_SHADER_API AccessMember : public Expression - { - inline AccessMember(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - ExpressionPtr structExpr; - ShaderExpressionType exprType; - std::vector memberIndices; - - static inline std::shared_ptr Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType); - static inline std::shared_ptr Build(ExpressionPtr structExpr, std::vector memberIndices, ShaderExpressionType exprType); - }; - - struct NAZARA_SHADER_API NoOp : public Statement - { - inline NoOp(); - - void Visit(ShaderAstVisitor& visitor) override; - - static inline std::shared_ptr Build(); - }; - - struct NAZARA_SHADER_API ReturnStatement : public Statement - { - inline ReturnStatement(); - - void Visit(ShaderAstVisitor& visitor) override; - - ExpressionPtr returnExpr; - - static inline std::shared_ptr Build(ExpressionPtr expr = nullptr); - }; - - ////////////////////////////////////////////////////////////////////////// - - struct NAZARA_SHADER_API AssignOp : public Expression - { - inline AssignOp(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - AssignType op; - ExpressionPtr left; - ExpressionPtr right; - - static inline std::shared_ptr Build(AssignType op, ExpressionPtr left, ExpressionPtr right); - }; - - struct NAZARA_SHADER_API BinaryOp : public Expression - { - inline BinaryOp(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - BinaryType op; - ExpressionPtr left; - ExpressionPtr right; - - static inline std::shared_ptr Build(BinaryType op, ExpressionPtr left, ExpressionPtr right); - }; - - struct NAZARA_SHADER_API Branch : public Statement - { - struct ConditionalStatement; - - inline Branch(); - - void Visit(ShaderAstVisitor& visitor) override; - - std::vector condStatements; - StatementPtr elseStatement; - - struct ConditionalStatement - { - ExpressionPtr condition; - StatementPtr statement; - }; - - static inline std::shared_ptr Build(ExpressionPtr condition, StatementPtr trueStatement, StatementPtr falseStatement = nullptr); - static inline std::shared_ptr Build(std::vector statements, StatementPtr elseStatement = nullptr); - }; - - struct NAZARA_SHADER_API Cast : public Expression - { - inline Cast(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - BasicType exprType; - std::array expressions; - - static inline std::shared_ptr Build(BasicType castTo, ExpressionPtr first, ExpressionPtr second = nullptr, ExpressionPtr third = nullptr, ExpressionPtr fourth = nullptr); - static inline std::shared_ptr Build(BasicType castTo, ExpressionPtr* expressions, std::size_t expressionCount); - }; - - struct NAZARA_SHADER_API ConditionalExpression : public Expression - { - inline ConditionalExpression(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - std::string conditionName; - ExpressionPtr falsePath; - ExpressionPtr truePath; - - static inline std::shared_ptr Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath); - }; - - struct NAZARA_SHADER_API Constant : public Expression - { - inline Constant(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - ShaderConstantValue value; - - template static std::shared_ptr Build(const T& value); - }; - - struct NAZARA_SHADER_API SwizzleOp : public Expression - { - inline SwizzleOp(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - std::array components; - std::size_t componentCount; - ExpressionPtr expression; - - static inline std::shared_ptr Build(ExpressionPtr expressionPtr, SwizzleComponent swizzleComponent); - static inline std::shared_ptr Build(ExpressionPtr expressionPtr, std::initializer_list swizzleComponents); - static inline std::shared_ptr Build(ExpressionPtr expressionPtr, const SwizzleComponent* components, std::size_t componentCount); - }; - - ////////////////////////////////////////////////////////////////////////// - - struct NAZARA_SHADER_API Sample2D : public Expression - { - inline Sample2D(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - ExpressionPtr sampler; - ExpressionPtr coordinates; - - static inline std::shared_ptr Build(ExpressionPtr samplerPtr, ExpressionPtr coordinatesPtr); - }; - - ////////////////////////////////////////////////////////////////////////// - - struct NAZARA_SHADER_API IntrinsicCall : public Expression - { - inline IntrinsicCall(); - - ShaderExpressionType GetExpressionType() const override; - void Visit(ShaderAstVisitor& visitor) override; - - IntrinsicType intrinsic; - std::vector parameters; - - static inline std::shared_ptr Build(IntrinsicType intrinsic, std::vector parameters); - }; - } + std::string name; + std::vector parameters; + std::vector statements; + ShaderExpressionType returnType = BasicType::Void; + }; + + struct NAZARA_SHADER_API DeclareStructStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + StructDescription description; + }; + + struct NAZARA_SHADER_API DeclareVariableStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + std::string varName; + ExpressionPtr initialExpression; + ShaderExpressionType varType; + }; + + struct NAZARA_SHADER_API DiscardStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + }; + + struct NAZARA_SHADER_API ExpressionStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + ExpressionPtr expression; + }; + + struct NAZARA_SHADER_API MultiStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + std::vector statements; + }; + + struct NAZARA_SHADER_API NoOpStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + }; + + struct NAZARA_SHADER_API ReturnStatement : Statement + { + NodeType GetType() const override; + void Visit(AstStatementVisitor& visitor) override; + + ExpressionPtr returnExpr; + }; } #include diff --git a/include/Nazara/Shader/ShaderNodes.inl b/include/Nazara/Shader/ShaderNodes.inl index c1e2b22c3..6c702a06e 100644 --- a/include/Nazara/Shader/ShaderNodes.inl +++ b/include/Nazara/Shader/ShaderNodes.inl @@ -5,394 +5,8 @@ #include #include -namespace Nz::ShaderNodes +namespace Nz::ShaderAst { - inline Node::Node(NodeType type, bool isStatement) : - m_type(type), - m_isStatement(isStatement) - { - } - - inline NodeType ShaderNodes::Node::GetType() const - { - return m_type; - } - - inline bool Node::IsStatement() const - { - return m_isStatement; - } - - inline unsigned int Node::GetComponentCount(BasicType type) - { - switch (type) - { - case BasicType::Float2: - case BasicType::Int2: - return 2; - - case BasicType::Float3: - case BasicType::Int3: - return 3; - - case BasicType::Float4: - case BasicType::Int4: - return 4; - - case BasicType::Mat4x4: - return 4; - - default: - return 1; - } - } - - inline BasicType Node::GetComponentType(BasicType type) - { - switch (type) - { - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - return BasicType::Float1; - - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - return BasicType::Int1; - - case BasicType::Mat4x4: - return BasicType::Float4; - - default: - return type; - } - } - - - inline Expression::Expression(NodeType type) : - Node(type, false) - { - } - - inline Statement::Statement(NodeType type) : - Node(type, true) - { - } - - - - inline ExpressionStatement::ExpressionStatement() : - Statement(NodeType::ExpressionStatement) - { - } - - inline std::shared_ptr ExpressionStatement::Build(ExpressionPtr expr) - { - auto node = std::make_shared(); - node->expression = std::move(expr); - - return node; - } - - inline ConditionalStatement::ConditionalStatement() : - Statement(NodeType::ConditionalStatement) - { - } - - inline std::shared_ptr ConditionalStatement::Build(std::string condition, StatementPtr statementPtr) - { - auto node = std::make_shared(); - node->conditionName = std::move(condition); - node->statement = std::move(statementPtr); - - return node; - } - - - inline StatementBlock::StatementBlock() : - Statement(NodeType::StatementBlock) - { - } - - inline std::shared_ptr StatementBlock::Build(std::vector statements) - { - auto node = std::make_shared(); - node->statements = std::move(statements); - - return node; - } - - template - std::shared_ptr StatementBlock::Build(Args&&... args) - { - auto node = std::make_shared(); - node->statements = std::vector({ std::forward(args)... }); - - return node; - } - - - inline DeclareVariable::DeclareVariable() : - Statement(NodeType::DeclareVariable) - { - } - - inline std::shared_ptr DeclareVariable::Build(VariablePtr variable, ExpressionPtr expression) - { - auto node = std::make_shared(); - node->expression = std::move(expression); - node->variable = std::move(variable); - - return node; - } - - - inline Discard::Discard() : - Statement(NodeType::Discard) - { - } - - inline std::shared_ptr Discard::Build() - { - return std::make_shared(); - } - - - inline Identifier::Identifier() : - Expression(NodeType::Identifier) - { - } - - inline std::shared_ptr Identifier::Build(VariablePtr variable) - { - auto node = std::make_shared(); - node->var = std::move(variable); - - return node; - } - - - inline AccessMember::AccessMember() : - Expression(NodeType::AccessMember) - { - } - - inline std::shared_ptr AccessMember::Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType) - { - return Build(std::move(structExpr), std::vector{ memberIndex }, exprType); - } - - inline std::shared_ptr AccessMember::Build(ExpressionPtr structExpr, std::vector memberIndices, ShaderExpressionType exprType) - { - auto node = std::make_shared(); - node->exprType = std::move(exprType); - node->memberIndices = std::move(memberIndices); - node->structExpr = std::move(structExpr); - - return node; - } - - - inline NoOp::NoOp() : - Statement(NodeType::NoOp) - { - } - - inline std::shared_ptr NoOp::Build() - { - return std::make_shared(); - } - - - inline ReturnStatement::ReturnStatement() : - Statement(NodeType::ReturnStatement) - { - } - - inline std::shared_ptr ShaderNodes::ReturnStatement::Build(ExpressionPtr expr) - { - auto node = std::make_shared(); - node->returnExpr = std::move(expr); - - return node; - } - - - inline AssignOp::AssignOp() : - Expression(NodeType::AssignOp) - { - } - - inline std::shared_ptr AssignOp::Build(AssignType op, ExpressionPtr left, ExpressionPtr right) - { - auto node = std::make_shared(); - node->op = op; - node->left = std::move(left); - node->right = std::move(right); - - return node; - } - - - inline BinaryOp::BinaryOp() : - Expression(NodeType::BinaryOp) - { - } - - inline std::shared_ptr BinaryOp::Build(BinaryType op, ExpressionPtr left, ExpressionPtr right) - { - auto node = std::make_shared(); - node->op = op; - node->left = std::move(left); - node->right = std::move(right); - - return node; - } - - - inline Branch::Branch() : - Statement(NodeType::Branch) - { - } - - inline std::shared_ptr Branch::Build(ExpressionPtr condition, StatementPtr trueStatement, StatementPtr falseStatement) - { - auto node = std::make_shared(); - node->condStatements.emplace_back(ConditionalStatement{ std::move(condition), std::move(trueStatement) }); - node->elseStatement = std::move(falseStatement); - - return node; - } - - inline std::shared_ptr Branch::Build(std::vector statements, StatementPtr elseStatement) - { - auto node = std::make_shared(); - node->condStatements = std::move(statements); - node->elseStatement = std::move(elseStatement); - - return node; - } - - - inline Cast::Cast() : - Expression(NodeType::Cast) - { - } - - inline std::shared_ptr Cast::Build(BasicType castTo, ExpressionPtr first, ExpressionPtr second, ExpressionPtr third, ExpressionPtr fourth) - { - auto node = std::make_shared(); - node->exprType = castTo; - node->expressions = { {first, second, third, fourth} }; - - return node; - } - - inline std::shared_ptr Cast::Build(BasicType castTo, ExpressionPtr* Expressions, std::size_t expressionCount) - { - auto node = std::make_shared(); - node->exprType = castTo; - for (std::size_t i = 0; i < expressionCount; ++i) - node->expressions[i] = Expressions[i]; - - return node; - } - - inline ConditionalExpression::ConditionalExpression() : - Expression(NodeType::ConditionalExpression) - { - } - - inline std::shared_ptr ShaderNodes::ConditionalExpression::Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath) - { - auto node = std::make_shared(); - node->conditionName = std::move(condition); - node->falsePath = std::move(falsePath); - node->truePath = std::move(truePath); - - return node; - } - - inline Constant::Constant() : - Expression(NodeType::Constant) - { - } - - template - std::shared_ptr Nz::ShaderNodes::Constant::Build(const T& value) - { - auto node = std::make_shared(); - node->value = value; - - return node; - } - - - inline SwizzleOp::SwizzleOp() : - Expression(NodeType::SwizzleOp) - { - } - - inline std::shared_ptr SwizzleOp::Build(ExpressionPtr expressionPtr, SwizzleComponent swizzleComponent) - { - return Build(std::move(expressionPtr), { swizzleComponent }); - } - - inline std::shared_ptr SwizzleOp::Build(ExpressionPtr expressionPtr, std::initializer_list swizzleComponents) - { - auto node = std::make_shared(); - node->componentCount = swizzleComponents.size(); - node->expression = std::move(expressionPtr); - - std::copy(swizzleComponents.begin(), swizzleComponents.end(), node->components.begin()); - - return node; - } - - inline std::shared_ptr SwizzleOp::Build(ExpressionPtr expressionPtr, const SwizzleComponent* components, std::size_t componentCount) - { - auto node = std::make_shared(); - - assert(componentCount < node->components.size()); - - node->componentCount = componentCount; - node->expression = std::move(expressionPtr); - - std::copy(components, components + componentCount, node->components.begin()); - - return node; - } - - - inline Sample2D::Sample2D() : - Expression(NodeType::Sample2D) - { - } - - inline std::shared_ptr Sample2D::Build(ExpressionPtr samplerPtr, ExpressionPtr coordinatesPtr) - { - auto node = std::make_shared(); - node->coordinates = std::move(coordinatesPtr); - node->sampler = std::move(samplerPtr); - - return node; - } - - - inline IntrinsicCall::IntrinsicCall() : - Expression(NodeType::IntrinsicCall) - { - } - - inline std::shared_ptr IntrinsicCall::Build(IntrinsicType intrinsic, std::vector parameters) - { - auto node = std::make_shared(); - node->intrinsic = intrinsic; - node->parameters = std::move(parameters); - - return node; - } } #include diff --git a/include/Nazara/Shader/ShaderVarVisitor.hpp b/include/Nazara/Shader/ShaderVarVisitor.hpp deleted file mode 100644 index babfb2b1e..000000000 --- a/include/Nazara/Shader/ShaderVarVisitor.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (C) 2015 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERVARVISITOR_HPP -#define NAZARA_SHADERVARVISITOR_HPP - -#include -#include -#include - -namespace Nz -{ - class NAZARA_SHADER_API ShaderVarVisitor - { - public: - ShaderVarVisitor() = default; - ShaderVarVisitor(const ShaderVarVisitor&) = delete; - ShaderVarVisitor(ShaderVarVisitor&&) = delete; - virtual ~ShaderVarVisitor(); - - void Visit(const ShaderNodes::VariablePtr& node); - - virtual void Visit(ShaderNodes::BuiltinVariable& var) = 0; - virtual void Visit(ShaderNodes::InputVariable& var) = 0; - virtual void Visit(ShaderNodes::LocalVariable& var) = 0; - virtual void Visit(ShaderNodes::OutputVariable& var) = 0; - virtual void Visit(ShaderNodes::ParameterVariable& var) = 0; - virtual void Visit(ShaderNodes::UniformVariable& var) = 0; - - ShaderVarVisitor& operator=(const ShaderVarVisitor&) = delete; - ShaderVarVisitor& operator=(ShaderVarVisitor&&) = delete; - }; -} - -#endif diff --git a/include/Nazara/Shader/ShaderVarVisitorExcept.hpp b/include/Nazara/Shader/ShaderVarVisitorExcept.hpp deleted file mode 100644 index 3fa769e21..000000000 --- a/include/Nazara/Shader/ShaderVarVisitorExcept.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (C) 2015 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADERVARVISITOREXCEPT_HPP -#define NAZARA_SHADERVARVISITOREXCEPT_HPP - -#include -#include - -namespace Nz -{ - class NAZARA_SHADER_API ShaderVarVisitorExcept : public ShaderVarVisitor - { - public: - using ShaderVarVisitor::Visit; - void Visit(ShaderNodes::BuiltinVariable& var) override; - void Visit(ShaderNodes::InputVariable& var) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::OutputVariable& var) override; - void Visit(ShaderNodes::ParameterVariable& var) override; - void Visit(ShaderNodes::UniformVariable& var) override; - }; -} - -#endif diff --git a/include/Nazara/Shader/ShaderVariables.hpp b/include/Nazara/Shader/ShaderVariables.hpp deleted file mode 100644 index eb0bc8ede..000000000 --- a/include/Nazara/Shader/ShaderVariables.hpp +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NAZARA_SHADER_VARIABLES_HPP -#define NAZARA_SHADER_VARIABLES_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace Nz -{ - class ShaderVarVisitor; - - namespace ShaderNodes - { - struct Variable; - - using VariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API Variable : std::enable_shared_from_this - { - virtual ~Variable(); - - virtual VariableType GetType() const = 0; - virtual void Visit(ShaderVarVisitor& visitor) = 0; - - ShaderExpressionType type; - }; - - struct BuiltinVariable; - - using BuiltinVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API BuiltinVariable : public Variable - { - BuiltinEntry entry; - - VariableType GetType() const override; - void Visit(ShaderVarVisitor& visitor) override; - - static inline std::shared_ptr Build(BuiltinEntry entry, ShaderExpressionType varType); - }; - - struct NamedVariable; - - using NamedVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API NamedVariable : public Variable - { - std::string name; - }; - - struct InputVariable; - - using InputVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API InputVariable : public NamedVariable - { - VariableType GetType() const override; - void Visit(ShaderVarVisitor& visitor) override; - - static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); - }; - - struct LocalVariable; - - using LocalVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API LocalVariable : public NamedVariable - { - VariableType GetType() const override; - void Visit(ShaderVarVisitor& visitor) override; - - static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); - }; - - struct OutputVariable; - - using OutputVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API OutputVariable : public NamedVariable - { - VariableType GetType() const override; - void Visit(ShaderVarVisitor& visitor) override; - - static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); - }; - - struct ParameterVariable; - - using ParameterVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API ParameterVariable : public NamedVariable - { - VariableType GetType() const override; - void Visit(ShaderVarVisitor& visitor) override; - - static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); - }; - - struct UniformVariable; - - using UniformVariablePtr = std::shared_ptr; - - struct NAZARA_SHADER_API UniformVariable : public NamedVariable - { - VariableType GetType() const override; - void Visit(ShaderVarVisitor& visitor) override; - - static inline std::shared_ptr Build(std::string varName, ShaderExpressionType varType); - }; - } -} - -#include - -#endif diff --git a/include/Nazara/Shader/ShaderVariables.inl b/include/Nazara/Shader/ShaderVariables.inl deleted file mode 100644 index 9f2415708..000000000 --- a/include/Nazara/Shader/ShaderVariables.inl +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz::ShaderNodes -{ - inline std::shared_ptr BuiltinVariable::Build(BuiltinEntry variable, ShaderExpressionType varType) - { - auto node = std::make_shared(); - node->entry = variable; - node->type = varType; - - return node; - } - - inline std::shared_ptr InputVariable::Build(std::string varName, ShaderExpressionType varType) - { - auto node = std::make_shared(); - node->name = std::move(varName); - node->type = varType; - - return node; - } - - inline std::shared_ptr LocalVariable::Build(std::string varName, ShaderExpressionType varType) - { - auto node = std::make_shared(); - node->name = std::move(varName); - node->type = varType; - - return node; - } - - inline std::shared_ptr OutputVariable::Build(std::string varName, ShaderExpressionType varType) - { - auto node = std::make_shared(); - node->name = std::move(varName); - node->type = varType; - - return node; - } - - inline std::shared_ptr ParameterVariable::Build(std::string varName, ShaderExpressionType varType) - { - auto node = std::make_shared(); - node->name = std::move(varName); - node->type = varType; - - return node; - } - - inline std::shared_ptr UniformVariable::Build(std::string varName, ShaderExpressionType varType) - { - auto node = std::make_shared(); - node->name = std::move(varName); - node->type = varType; - - return node; - } -} - -#include diff --git a/include/Nazara/Shader/ShaderWriter.hpp b/include/Nazara/Shader/ShaderWriter.hpp index 337c3c3cf..47fb04a6f 100644 --- a/include/Nazara/Shader/ShaderWriter.hpp +++ b/include/Nazara/Shader/ShaderWriter.hpp @@ -14,8 +14,6 @@ namespace Nz { - class ShaderAst; - class NAZARA_SHADER_API ShaderWriter { public: diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index cdf6fa2fe..ffead5fef 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -9,8 +9,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -18,7 +18,7 @@ namespace Nz { class SpirvWriter; - class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAstVisitorExcept + class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept { public: inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks); @@ -26,27 +26,28 @@ namespace Nz SpirvAstVisitor(SpirvAstVisitor&&) = delete; ~SpirvAstVisitor() = default; - UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr); + UInt32 EvaluateExpression(ShaderAst::ExpressionPtr& expr); - using ShaderAstVisitorExcept::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::ConditionalExpression& node) override; - void Visit(ShaderNodes::ConditionalStatement& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::Discard& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::NoOp& node) override; - void Visit(ShaderNodes::ReturnStatement& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; + using ExpressionVisitorExcept::Visit; + using StatementVisitorExcept::Visit; + + void Visit(ShaderAst::AccessMemberExpression& node) override; + void Visit(ShaderAst::AssignExpression& node) override; + void Visit(ShaderAst::BinaryExpression& node) override; + void Visit(ShaderAst::BranchStatement& node) override; + void Visit(ShaderAst::CastExpression& node) override; + void Visit(ShaderAst::ConditionalExpression& node) override; + void Visit(ShaderAst::ConditionalStatement& node) override; + void Visit(ShaderAst::ConstantExpression& node) override; + void Visit(ShaderAst::DeclareVariableStatement& node) override; + void Visit(ShaderAst::DiscardStatement& node) override; + void Visit(ShaderAst::ExpressionStatement& node) override; + void Visit(ShaderAst::IdentifierExpression& node) override; + void Visit(ShaderAst::IntrinsicExpression& node) override; + void Visit(ShaderAst::MultiStatement& node) override; + void Visit(ShaderAst::NoOpStatement& node) override; + void Visit(ShaderAst::ReturnStatement& node) override; + void Visit(ShaderAst::SwizzleExpression& node) override; SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete; SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete; diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index 279f71b07..54a53584c 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -20,7 +20,6 @@ namespace Nz { - class ShaderAst; class SpirvSection; class NAZARA_SHADER_API SpirvConstantCache @@ -173,10 +172,10 @@ namespace Nz SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; static ConstantPtr BuildConstant(const ShaderConstantValue& value); - static TypePtr BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass); - static TypePtr BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass); - static TypePtr BuildType(const ShaderNodes::BasicType& type); - static TypePtr BuildType(const ShaderAst& shader, const ShaderExpressionType& type); + static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass); + static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass); + static TypePtr BuildType(const ShaderAst::BasicType& type); + static TypePtr BuildType(const ShaderAst::ShaderExpressionType& type); private: struct DepRegisterer; diff --git a/include/Nazara/Shader/SpirvExpressionLoad.hpp b/include/Nazara/Shader/SpirvExpressionLoad.hpp index f59369a52..ddc9551fd 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.hpp +++ b/include/Nazara/Shader/SpirvExpressionLoad.hpp @@ -9,8 +9,7 @@ #include #include -#include -#include +#include #include #include @@ -19,7 +18,7 @@ namespace Nz class SpirvBlock; class SpirvWriter; - class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept + class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAst::ExpressionVisitorExcept { public: inline SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block); @@ -27,17 +26,11 @@ namespace Nz SpirvExpressionLoad(SpirvExpressionLoad&&) = delete; ~SpirvExpressionLoad() = default; - UInt32 Evaluate(ShaderNodes::Expression& node); + UInt32 Evaluate(ShaderAst::Expression& node); - using ShaderAstVisitor::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::Identifier& node) override; - - using ShaderVarVisitor::Visit; - void Visit(ShaderNodes::InputVariable& var) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::ParameterVariable& var) override; - void Visit(ShaderNodes::UniformVariable& var) override; + using ExpressionVisitorExcept::Visit; + //void Visit(ShaderAst::AccessMemberExpression& node) override; + void Visit(ShaderAst::IdentifierExpression& node) override; SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete; SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete; diff --git a/include/Nazara/Shader/SpirvExpressionStore.hpp b/include/Nazara/Shader/SpirvExpressionStore.hpp index 26c2b5f48..ee0d96f6a 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.hpp +++ b/include/Nazara/Shader/SpirvExpressionStore.hpp @@ -9,8 +9,7 @@ #include #include -#include -#include +#include #include namespace Nz @@ -18,7 +17,7 @@ namespace Nz class SpirvBlock; class SpirvWriter; - class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept + class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAst::ExpressionVisitorExcept { public: inline SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block); @@ -26,17 +25,12 @@ namespace Nz SpirvExpressionStore(SpirvExpressionStore&&) = delete; ~SpirvExpressionStore() = default; - void Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId); + void Store(ShaderAst::ExpressionPtr& node, UInt32 resultId); - using ShaderAstVisitorExcept::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; - - using ShaderVarVisitorExcept::Visit; - void Visit(ShaderNodes::BuiltinVariable& var) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::OutputVariable& var) override; + using ExpressionVisitorExcept::Visit; + //void Visit(ShaderAst::AccessMemberExpression& node) override; + void Visit(ShaderAst::IdentifierExpression& node) override; + void Visit(ShaderAst::SwizzleExpression& node) override; SpirvExpressionStore& operator=(const SpirvExpressionStore&) = delete; SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete; diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index eeaee1e60..a8af651cd 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -9,10 +9,8 @@ #include #include -#include -#include +#include #include -#include #include #include #include @@ -39,7 +37,7 @@ namespace Nz SpirvWriter(SpirvWriter&&) = delete; ~SpirvWriter() = default; - std::vector Generate(const ShaderAst& shader, const States& conditions = {}); + std::vector Generate(ShaderAst::StatementPtr& shader, const States& conditions = {}); void SetEnv(Environment environment); @@ -51,22 +49,23 @@ namespace Nz private: struct ExtVar; + struct FunctionParameter; struct OnlyCache {}; UInt32 AllocateResultId(); void AppendHeader(); - SpirvConstantCache::Function BuildFunctionType(ShaderExpressionType retType, const std::vector& parameters); + SpirvConstantCache::Function BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters); UInt32 GetConstantId(const ShaderConstantValue& value) const; - UInt32 GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters); - const ExtVar& GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const; + UInt32 GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector& parameters); + const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const; const ExtVar& GetInputVariable(const std::string& name) const; const ExtVar& GetOutputVariable(const std::string& name) const; const ExtVar& GetUniformVariable(const std::string& name) const; - UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const; - UInt32 GetTypeId(const ShaderExpressionType& type) const; + UInt32 GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const; + UInt32 GetTypeId(const ShaderAst::ShaderExpressionType& type) const; inline bool IsConditionEnabled(const std::string& condition) const; @@ -82,9 +81,9 @@ namespace Nz std::optional ReadVariable(const ExtVar& var, OnlyCache); UInt32 RegisterConstant(const ShaderConstantValue& value); - UInt32 RegisterFunctionType(ShaderExpressionType retType, const std::vector& parameters); - UInt32 RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass); - UInt32 RegisterType(ShaderExpressionType type); + UInt32 RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters); + UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass); + UInt32 RegisterType(ShaderAst::ShaderExpressionType type); void WriteLocalVariable(std::string name, UInt32 resultId); @@ -92,7 +91,7 @@ namespace Nz struct Context { - const ShaderAst* shader = nullptr; + ShaderAst::AstCache cache; const States* states = nullptr; std::vector functionBlocks; }; @@ -105,6 +104,12 @@ namespace Nz std::optional valueId; }; + struct FunctionParameter + { + std::string name; + ShaderAst::ShaderExpressionType type; + }; + struct State; Context m_context; diff --git a/include/Nazara/Shader/SpirvWriter.inl b/include/Nazara/Shader/SpirvWriter.inl index ed518e5f4..903d6265b 100644 --- a/include/Nazara/Shader/SpirvWriter.inl +++ b/include/Nazara/Shader/SpirvWriter.inl @@ -10,10 +10,11 @@ namespace Nz { inline bool SpirvWriter::IsConditionEnabled(const std::string& condition) const { - std::size_t conditionIndex = m_context.shader->FindConditionByName(condition); + /*std::size_t conditionIndex = m_context.shader->FindConditionByName(condition); assert(conditionIndex != ShaderAst::InvalidCondition); - return TestBit(m_context.states->enabledConditions, conditionIndex); + return TestBit(m_context.states->enabledConditions, conditionIndex);*/ + return false; } } diff --git a/src/Nazara/Graphics/BasicMaterial.cpp b/src/Nazara/Graphics/BasicMaterial.cpp index 089438336..220efd6bb 100644 --- a/src/Nazara/Graphics/BasicMaterial.cpp +++ b/src/Nazara/Graphics/BasicMaterial.cpp @@ -167,8 +167,8 @@ namespace Nz auto& fragmentShader = settings.shaders[UnderlyingCast(ShaderStageType::Fragment)]; auto& vertexShader = settings.shaders[UnderlyingCast(ShaderStageType::Vertex)]; - fragmentShader = std::make_shared(UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader))); - vertexShader = std::make_shared(UnserializeShader(r_vertexShader, sizeof(r_vertexShader))); + fragmentShader = std::make_shared(ShaderAst::UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader))); + vertexShader = std::make_shared(ShaderAst::UnserializeShader(r_vertexShader, sizeof(r_vertexShader))); // Conditions diff --git a/src/Nazara/Graphics/UberShader.cpp b/src/Nazara/Graphics/UberShader.cpp index 6fb0fcba0..fc2364a24 100644 --- a/src/Nazara/Graphics/UberShader.cpp +++ b/src/Nazara/Graphics/UberShader.cpp @@ -5,17 +5,17 @@ #include #include #include -#include #include #include #include namespace Nz { - UberShader::UberShader(ShaderAst shaderAst) : + UberShader::UberShader(ShaderAst::StatementPtr shaderAst) : m_shaderAst(std::move(shaderAst)) { - std::size_t conditionCount = m_shaderAst.GetConditionCount(); + //std::size_t conditionCount = m_shaderAst.GetConditionCount(); + std::size_t conditionCount = 0; if (conditionCount >= 64) throw std::runtime_error("Too many conditions"); @@ -27,10 +27,10 @@ namespace Nz UInt64 UberShader::GetConditionFlagByName(const std::string_view& condition) const { - std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition); + /*std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition); if (conditionIndex != ShaderAst::InvalidCondition) return SetBit(0, conditionIndex); - else + else*/ return 0; } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 1920bf291..bb9ce8433 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -20,64 +21,67 @@ namespace Nz { static const char* flipYUniformName = "_NzFlipValue"; - struct AstAdapter : ShaderAstCloner + struct AstAdapter : ShaderAst::AstCloner { - void Visit(ShaderNodes::AssignOp& node) override + void Visit(ShaderAst::AssignExpression& node) override { if (!flipYPosition) + return AstCloner::Visit(node); + + if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression) + return AstCloner::Visit(node); + + /* + FIXME: + const auto& identifier = static_cast(*node.left); + if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable) return ShaderAstCloner::Visit(node); - if (node.left->GetType() != ShaderNodes::NodeType::Identifier) + const auto& builtinVar = static_cast(*identifier.var); + if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition) return ShaderAstCloner::Visit(node); - const auto& identifier = static_cast(*node.left); - if (identifier.var->GetType() != ShaderNodes::VariableType::BuiltinVariable) - return ShaderAstCloner::Visit(node); - - const auto& builtinVar = static_cast(*identifier.var); - if (builtinVar.entry != ShaderNodes::BuiltinEntry::VertexPosition) - return ShaderAstCloner::Visit(node); - - auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderNodes::BasicType::Float1); + auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1); auto oneConstant = ShaderBuilder::Constant(1.f); - auto fixYValue = ShaderBuilder::Cast(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant); + auto fixYValue = ShaderBuilder::Cast(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant); auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue); - PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), mulFix)); + PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/ } bool flipYPosition = false; }; } + + struct GlslWriter::State + { + const States* states = nullptr; + ShaderAst::AstCache cache; + std::stringstream stream; + unsigned int indentLevel = 0; + }; + + GlslWriter::GlslWriter() : m_currentState(nullptr) { } - std::string GlslWriter::Generate(const ShaderAst& inputShader, const States& conditions) + std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) { - const ShaderAst* selectedShader = &inputShader; + /*const ShaderAst* selectedShader = &inputShader; std::optional modifiedShader; if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition) { modifiedShader.emplace(inputShader); - modifiedShader->AddUniform(flipYUniformName, ShaderNodes::BasicType::Float1); + modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1); selectedShader = &modifiedShader.value(); - } - - const ShaderAst& shader = *selectedShader; - - std::string error; - if (!ValidateShader(shader, &error)) - throw std::runtime_error("Invalid shader AST: " + error); - - m_context.states = &conditions; - m_context.shader = &shader; - + }*/ + State state; m_currentState = &state; CallOnExit onExit([this]() @@ -85,6 +89,10 @@ namespace Nz m_currentState = nullptr; }); + std::string error; + if (!ShaderAst::ValidateAst(shader, &error, &state.cache)) + throw std::runtime_error("Invalid shader AST: " + error); + unsigned int glslVersion; if (m_environment.glES) { @@ -165,52 +173,7 @@ namespace Nz AppendLine(); } - // Structures - /*if (shader.GetStructCount() > 0) - { - AppendCommentSection("Structures"); - for (const auto& s : shader.GetStructs()) - { - Append("struct "); - AppendLine(s.name); - AppendLine("{"); - for (const auto& m : s.members) - { - Append("\t"); - Append(m.type); - Append(" "); - Append(m.name); - AppendLine(";"); - } - AppendLine("};"); - AppendLine(); - } - }*/ - - // Global variables (uniforms, input and outputs) - const char* inKeyword = (glslVersion >= 130) ? "in" : "varying"; - const char* outKeyword = (glslVersion >= 130) ? "out" : "varying"; - - DeclareVariables(shader, shader.GetUniforms(), "uniform", "Uniforms"); - DeclareVariables(shader, shader.GetInputs(), inKeyword, "Inputs"); - DeclareVariables(shader, shader.GetOutputs(), outKeyword, "Outputs"); - - std::size_t functionCount = shader.GetFunctionCount(); - if (functionCount > 1) - { - AppendCommentSection("Prototypes"); - for (const auto& func : shader.GetFunctions()) - { - if (func.name != "main") - { - AppendFunctionPrototype(func); - AppendLine(";"); - } - } - } - - for (const auto& func : shader.GetFunctions()) - AppendFunction(func); + shader->Visit(*this); return state.stream.str(); } @@ -225,7 +188,7 @@ namespace Nz return flipYUniformName; } - void GlslWriter::Append(ShaderExpressionType type) + void GlslWriter::Append(ShaderAst::ShaderExpressionType type) { std::visit([&](auto&& arg) { @@ -233,49 +196,57 @@ namespace Nz }, type); } - void GlslWriter::Append(ShaderNodes::BuiltinEntry builtin) + void GlslWriter::Append(ShaderAst::BuiltinEntry builtin) { switch (builtin) { - case ShaderNodes::BuiltinEntry::VertexPosition: + case ShaderAst::BuiltinEntry::VertexPosition: Append("gl_Position"); break; } } - void GlslWriter::Append(ShaderNodes::BasicType type) + void GlslWriter::Append(ShaderAst::BasicType type) { switch (type) { - case ShaderNodes::BasicType::Boolean: return Append("bool"); - case ShaderNodes::BasicType::Float1: return Append("float"); - case ShaderNodes::BasicType::Float2: return Append("vec2"); - case ShaderNodes::BasicType::Float3: return Append("vec3"); - case ShaderNodes::BasicType::Float4: return Append("vec4"); - case ShaderNodes::BasicType::Int1: return Append("int"); - case ShaderNodes::BasicType::Int2: return Append("ivec2"); - case ShaderNodes::BasicType::Int3: return Append("ivec3"); - case ShaderNodes::BasicType::Int4: return Append("ivec4"); - case ShaderNodes::BasicType::Mat4x4: return Append("mat4"); - case ShaderNodes::BasicType::Sampler2D: return Append("sampler2D"); - case ShaderNodes::BasicType::UInt1: return Append("uint"); - case ShaderNodes::BasicType::UInt2: return Append("uvec2"); - case ShaderNodes::BasicType::UInt3: return Append("uvec3"); - case ShaderNodes::BasicType::UInt4: return Append("uvec4"); - case ShaderNodes::BasicType::Void: return Append("void"); + case ShaderAst::BasicType::Boolean: return Append("bool"); + case ShaderAst::BasicType::Float1: return Append("float"); + case ShaderAst::BasicType::Float2: return Append("vec2"); + case ShaderAst::BasicType::Float3: return Append("vec3"); + case ShaderAst::BasicType::Float4: return Append("vec4"); + case ShaderAst::BasicType::Int1: return Append("int"); + case ShaderAst::BasicType::Int2: return Append("ivec2"); + case ShaderAst::BasicType::Int3: return Append("ivec3"); + case ShaderAst::BasicType::Int4: return Append("ivec4"); + case ShaderAst::BasicType::Mat4x4: return Append("mat4"); + case ShaderAst::BasicType::Sampler2D: return Append("sampler2D"); + case ShaderAst::BasicType::UInt1: return Append("uint"); + case ShaderAst::BasicType::UInt2: return Append("uvec2"); + case ShaderAst::BasicType::UInt3: return Append("uvec3"); + case ShaderAst::BasicType::UInt4: return Append("uvec4"); + case ShaderAst::BasicType::Void: return Append("void"); } } - void GlslWriter::Append(ShaderNodes::MemoryLayout layout) + void GlslWriter::Append(ShaderAst::MemoryLayout layout) { switch (layout) { - case ShaderNodes::MemoryLayout::Std140: + case ShaderAst::MemoryLayout::Std140: Append("std140"); break; } } + template + void GlslWriter::Append(const T& param) + { + NazaraAssert(m_currentState, "This function should only be called while processing an AST"); + + m_currentState->stream << param; + } + void GlslWriter::AppendCommentSection(const std::string& section) { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); @@ -285,67 +256,24 @@ namespace Nz AppendLine(); } - void GlslWriter::AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers) + void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { - const auto& structs = m_context.shader->GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); - assert(it != structs.end()); - - const ShaderAst::Struct& s = *it; - assert(*memberIndex < s.members.size()); - - const auto& member = s.members[*memberIndex]; Append("."); - Append(member.name); + Append(memberIdentifier[0]); + + const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName); + assert(identifier); + + assert(std::holds_alternative(identifier->value)); + const auto& s = std::get(identifier->value); + + auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); + assert(memberIt != s.members.end()); + + const auto& member = *memberIt; if (remainingMembers > 1) - { - assert(IsStructType(member.type)); - AppendField(std::get(member.type), memberIndex + 1, remainingMembers - 1); - } - } - - void GlslWriter::AppendFunction(const ShaderAst::Function& func) - { - NazaraAssert(!m_context.currentFunction, "A function is already being processed"); - NazaraAssert(m_currentState, "This function should only be called while processing an AST"); - - AppendFunctionPrototype(func); - - m_context.currentFunction = &func; - CallOnExit onExit([this] () - { - m_context.currentFunction = nullptr; - }); - - EnterScope(); - { - AstAdapter adapter; - adapter.flipYPosition = m_environment.flipYPosition; - - Visit(adapter.Clone(func.statement)); - } - LeaveScope(); - } - - void GlslWriter::AppendFunctionPrototype(const ShaderAst::Function& func) - { - Append(func.returnType); - - Append(" "); - Append(func.name); - - Append("("); - for (std::size_t i = 0; i < func.parameters.size(); ++i) - { - if (i != 0) - Append(", "); - - Append(func.parameters[i].type); - Append(" "); - Append(func.parameters[i].name); - } - Append(")\n"); + AppendField(scopeId, std::get(member.type), memberIdentifier + 1, remainingMembers - 1); } void GlslWriter::AppendLine(const std::string& txt) @@ -372,44 +300,46 @@ namespace Nz AppendLine("}"); } - void GlslWriter::Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired) + void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired) { - bool enclose = encloseIfRequired && (GetExpressionCategory(expr) != ShaderNodes::ExpressionCategory::LValue); + bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue); if (enclose) Append("("); - ShaderAstVisitor::Visit(expr); + expr->Visit(*this); if (enclose) Append(")"); } - void GlslWriter::Visit(ShaderNodes::AccessMember& node) + void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node) { Visit(node.structExpr, true); - const ShaderExpressionType& exprType = node.structExpr->GetExpressionType(); + const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache); assert(IsStructType(exprType)); - AppendField(std::get(exprType), node.memberIndices.data(), node.memberIndices.size()); + std::size_t scopeId = m_currentState->cache.GetScopeId(&node); + + AppendField(scopeId, std::get(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size()); } - void GlslWriter::Visit(ShaderNodes::AssignOp& node) + void GlslWriter::Visit(ShaderAst::AssignExpression& node) { - Visit(node.left); + node.left->Visit(*this); switch (node.op) { - case ShaderNodes::AssignType::Simple: + case ShaderAst::AssignType::Simple: Append(" = "); break; } - Visit(node.right); + node.left->Visit(*this); } - void GlslWriter::Visit(ShaderNodes::Branch& node) + void GlslWriter::Visit(ShaderAst::BranchStatement& node) { bool first = true; for (const auto& statement : node.condStatements) @@ -418,11 +348,11 @@ namespace Nz Append("else "); Append("if ("); - Visit(statement.condition); + statement.condition->Visit(*this); AppendLine(")"); EnterScope(); - Visit(statement.statement); + statement.statement->Visit(*this); LeaveScope(); first = false; @@ -433,41 +363,36 @@ namespace Nz AppendLine("else"); EnterScope(); - Visit(node.elseStatement); + node.elseStatement->Visit(*this); LeaveScope(); } } - void GlslWriter::Visit(ShaderNodes::BinaryOp& node) + void GlslWriter::Visit(ShaderAst::BinaryExpression& node) { Visit(node.left, true); switch (node.op) { - case ShaderNodes::BinaryType::Add: Append(" + "); break; - case ShaderNodes::BinaryType::Subtract: Append(" - "); break; - case ShaderNodes::BinaryType::Multiply: Append(" * "); break; - case ShaderNodes::BinaryType::Divide: Append(" / "); break; + case ShaderAst::BinaryType::Add: Append(" + "); break; + case ShaderAst::BinaryType::Subtract: Append(" - "); break; + case ShaderAst::BinaryType::Multiply: Append(" * "); break; + case ShaderAst::BinaryType::Divide: Append(" / "); break; - case ShaderNodes::BinaryType::CompEq: Append(" == "); break; - case ShaderNodes::BinaryType::CompGe: Append(" >= "); break; - case ShaderNodes::BinaryType::CompGt: Append(" > "); break; - case ShaderNodes::BinaryType::CompLe: Append(" <= "); break; - case ShaderNodes::BinaryType::CompLt: Append(" < "); break; - case ShaderNodes::BinaryType::CompNe: Append(" != "); break; + case ShaderAst::BinaryType::CompEq: Append(" == "); break; + case ShaderAst::BinaryType::CompGe: Append(" >= "); break; + case ShaderAst::BinaryType::CompGt: Append(" > "); break; + case ShaderAst::BinaryType::CompLe: Append(" <= "); break; + case ShaderAst::BinaryType::CompLt: Append(" < "); break; + case ShaderAst::BinaryType::CompNe: Append(" != "); break; } Visit(node.right, true); } - void GlslWriter::Visit(ShaderNodes::BuiltinVariable& var) + void GlslWriter::Visit(ShaderAst::CastExpression& node) { - Append(var.entry); - } - - void GlslWriter::Visit(ShaderNodes::Cast& node) - { - Append(node.exprType); + Append(node.targetType); Append("("); bool first = true; @@ -479,34 +404,34 @@ namespace Nz if (!first) m_currentState->stream << ", "; - Visit(exprPtr); + exprPtr->Visit(*this); first = false; } Append(")"); } - void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node) + void GlslWriter::Visit(ShaderAst::ConditionalExpression& node) { - std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName); + /*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName); assert(conditionIndex != ShaderAst::InvalidCondition); if (TestBit(m_context.states->enabledConditions, conditionIndex)) Visit(node.truePath); else - Visit(node.falsePath); + Visit(node.falsePath);*/ } - void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node) + void GlslWriter::Visit(ShaderAst::ConditionalStatement& node) { - std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName); + /*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName); assert(conditionIndex != ShaderAst::InvalidCondition); if (TestBit(m_context.states->enabledConditions, conditionIndex)) - Visit(node.statement); + Visit(node.statement);*/ } - void GlslWriter::Visit(ShaderNodes::Constant& node) + void GlslWriter::Visit(ShaderAst::ConstantExpression& node) { std::visit([&](auto&& arg) { @@ -530,54 +455,74 @@ namespace Nz }, node.value); } - void GlslWriter::Visit(ShaderNodes::DeclareVariable& node) + void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node) { - assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); + NazaraAssert(m_currentState, "This function should only be called while processing an AST"); - const auto& localVar = static_cast(*node.variable); - - Append(localVar.type); + Append(node.returnType); Append(" "); - Append(localVar.name); - if (node.expression) + Append(node.name); + Append("("); + for (std::size_t i = 0; i < node.parameters.size(); ++i) + { + if (i != 0) + Append(", "); + Append(node.parameters[i].type); + Append(" "); + Append(node.parameters[i].name); + } + Append(")\n"); + + EnterScope(); + { + AstAdapter adapter; + adapter.flipYPosition = m_environment.flipYPosition; + + for (auto& statement : node.statements) + adapter.Clone(statement)->Visit(*this); + } + LeaveScope(); + } + + void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node) + { + Append(node.varType); + Append(" "); + Append(node.varName); + if (node.initialExpression) { Append(" = "); - Visit(node.expression); + node.initialExpression->Visit(*this); } AppendLine(";"); } - void GlslWriter::Visit(ShaderNodes::Discard& /*node*/) + void GlslWriter::Visit(ShaderAst::DiscardStatement& /*node*/) { Append("discard;"); } - void GlslWriter::Visit(ShaderNodes::ExpressionStatement& node) + void GlslWriter::Visit(ShaderAst::ExpressionStatement& node) { - Visit(node.expression); + node.expression->Visit(*this); Append(";"); } - void GlslWriter::Visit(ShaderNodes::Identifier& node) + void GlslWriter::Visit(ShaderAst::IdentifierExpression& node) { - Visit(node.var); + Append(node.identifier); } - void GlslWriter::Visit(ShaderNodes::InputVariable& var) - { - Append(var.name); - } - - void GlslWriter::Visit(ShaderNodes::IntrinsicCall& node) + void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node) { switch (node.intrinsic) { - case ShaderNodes::IntrinsicType::CrossProduct: + case ShaderAst::IntrinsicType::CrossProduct: Append("cross"); break; - case ShaderNodes::IntrinsicType::DotProduct: + case ShaderAst::IntrinsicType::DotProduct: Append("dot"); break; } @@ -588,67 +533,43 @@ namespace Nz if (i != 0) Append(", "); - Visit(node.parameters[i]); + node.parameters[i]->Visit(*this); } Append(")"); } - void GlslWriter::Visit(ShaderNodes::LocalVariable& var) + void GlslWriter::Visit(ShaderAst::MultiStatement& node) { - Append(var.name); + bool first = true; + for (const ShaderAst::StatementPtr& statement : node.statements) + { + if (!first && statement->GetType() != ShaderAst::NodeType::NoOpStatement) + AppendLine(); + + statement->Visit(*this); + + first = false; + } } - void GlslWriter::Visit(ShaderNodes::NoOp& /*node*/) + void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/) { /* nothing to do */ } - void GlslWriter::Visit(ShaderNodes::ParameterVariable& var) - { - Append(var.name); - } - - void GlslWriter::Visit(ShaderNodes::ReturnStatement& node) + void GlslWriter::Visit(ShaderAst::ReturnStatement& node) { if (node.returnExpr) { Append("return "); - Visit(node.returnExpr); + node.returnExpr->Visit(*this); Append(";"); } else Append("return;"); } - void GlslWriter::Visit(ShaderNodes::OutputVariable& var) - { - Append(var.name); - } - - void GlslWriter::Visit(ShaderNodes::Sample2D& node) - { - Append("texture("); - Visit(node.sampler); - Append(", "); - Visit(node.coordinates); - Append(")"); - } - - void GlslWriter::Visit(ShaderNodes::StatementBlock& node) - { - bool first = true; - for (const ShaderNodes::StatementPtr& statement : node.statements) - { - if (!first && statement->GetType() != ShaderNodes::NodeType::NoOp) - AppendLine(); - - Visit(statement); - - first = false; - } - } - - void GlslWriter::Visit(ShaderNodes::SwizzleOp& node) + void GlslWriter::Visit(ShaderAst::SwizzleExpression& node) { Visit(node.expression, true); Append("."); @@ -657,44 +578,39 @@ namespace Nz { switch (node.components[i]) { - case ShaderNodes::SwizzleComponent::First: + case ShaderAst::SwizzleComponent::First: Append("x"); break; - case ShaderNodes::SwizzleComponent::Second: + case ShaderAst::SwizzleComponent::Second: Append("y"); break; - case ShaderNodes::SwizzleComponent::Third: + case ShaderAst::SwizzleComponent::Third: Append("z"); break; - case ShaderNodes::SwizzleComponent::Fourth: + case ShaderAst::SwizzleComponent::Fourth: Append("w"); break; } } } - void GlslWriter::Visit(ShaderNodes::UniformVariable& var) + bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader) { - Append(var.name); - } - - bool GlslWriter::HasExplicitBinding(const ShaderAst& shader) - { - for (const auto& uniform : shader.GetUniforms()) + /*for (const auto& uniform : shader.GetUniforms()) { if (uniform.bindingIndex.has_value()) return true; - } + }*/ return false; } - bool GlslWriter::HasExplicitLocation(const ShaderAst& shader) + bool GlslWriter::HasExplicitLocation(ShaderAst::StatementPtr& shader) { - for (const auto& input : shader.GetInputs()) + /*for (const auto& input : shader.GetInputs()) { if (input.locationIndex.has_value()) return true; @@ -704,7 +620,7 @@ namespace Nz { if (output.locationIndex.has_value()) return true; - } + }*/ return false; } diff --git a/src/Nazara/Shader/ShaderAst.cpp b/src/Nazara/Shader/ShaderAst.cpp deleted file mode 100644 index ff84cfa55..000000000 --- a/src/Nazara/Shader/ShaderAst.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include - -namespace Nz -{ - void ShaderAst::AddCondition(std::string name) - { - auto& conditionEntry = m_conditions.emplace_back(); - conditionEntry.name = std::move(name); - } - - void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters, ShaderExpressionType returnType) - { - auto& functionEntry = m_functions.emplace_back(); - functionEntry.name = std::move(name); - functionEntry.parameters = std::move(parameters); - functionEntry.returnType = returnType; - functionEntry.statement = std::move(statement); - } - - void ShaderAst::AddInput(std::string name, ShaderExpressionType type, std::optional locationIndex) - { - auto& inputEntry = m_inputs.emplace_back(); - inputEntry.name = std::move(name); - inputEntry.locationIndex = std::move(locationIndex); - inputEntry.type = std::move(type); - } - - void ShaderAst::AddOutput(std::string name, ShaderExpressionType type, std::optional locationIndex) - { - auto& outputEntry = m_outputs.emplace_back(); - outputEntry.name = std::move(name); - outputEntry.locationIndex = std::move(locationIndex); - outputEntry.type = std::move(type); - } - - void ShaderAst::AddStruct(std::string name, std::vector members) - { - auto& structEntry = m_structs.emplace_back(); - structEntry.name = std::move(name); - structEntry.members = std::move(members); - } - - void ShaderAst::AddUniform(std::string name, ShaderExpressionType type, std::optional bindingIndex, std::optional memoryLayout) - { - auto& uniformEntry = m_uniforms.emplace_back(); - uniformEntry.bindingIndex = std::move(bindingIndex); - uniformEntry.memoryLayout = std::move(memoryLayout); - uniformEntry.name = std::move(name); - uniformEntry.type = std::move(type); - } -} diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index 1118029df..2948cd753 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -6,240 +6,257 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - ShaderNodes::StatementPtr ShaderAstCloner::Clone(const ShaderNodes::StatementPtr& statement) + ExpressionPtr AstCloner::Clone(ExpressionPtr& expr) { - ShaderAstVisitor::Visit(statement); + expr->Visit(*this); - if (!m_expressionStack.empty() || !m_variableStack.empty() || m_statementStack.size() != 1) - throw std::runtime_error("An error occurred during clone"); + assert(m_statementStack.empty() && m_expressionStack.size() == 1); + return PopExpression(); + } + StatementPtr AstCloner::Clone(StatementPtr& statement) + { + statement->Visit(*this); + + assert(m_expressionStack.empty() && m_statementStack.size() == 1); return PopStatement(); } - ShaderNodes::ExpressionPtr ShaderAstCloner::CloneExpression(const ShaderNodes::ExpressionPtr& expr) + ExpressionPtr AstCloner::CloneExpression(ExpressionPtr& expr) { if (!expr) return nullptr; - ShaderAstVisitor::Visit(expr); + expr->Visit(*this); return PopExpression(); } - ShaderNodes::StatementPtr ShaderAstCloner::CloneStatement(const ShaderNodes::StatementPtr& statement) + StatementPtr AstCloner::CloneStatement(StatementPtr& statement) { if (!statement) return nullptr; - ShaderAstVisitor::Visit(statement); + statement->Visit(*this); return PopStatement(); } - ShaderNodes::VariablePtr ShaderAstCloner::CloneVariable(const ShaderNodes::VariablePtr& variable) + void AstCloner::Visit(AccessMemberExpression& node) { - if (!variable) - return nullptr; + auto clone = std::make_unique(); + clone->memberIdentifiers = node.memberIdentifiers; + clone->structExpr = CloneExpression(node.structExpr); - ShaderVarVisitor::Visit(variable); - return PopVariable(); + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::AccessMember& node) + void AstCloner::Visit(AssignExpression& node) { - PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndices, node.exprType)); + auto clone = std::make_unique(); + clone->op = node.op; + clone->left = CloneExpression(node.left); + clone->right = CloneExpression(node.right); + + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::AssignOp& node) + void AstCloner::Visit(BinaryExpression& node) { - PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right))); + auto clone = std::make_unique(); + clone->op = node.op; + clone->left = CloneExpression(node.left); + clone->right = CloneExpression(node.right); + + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::BinaryOp& node) + void AstCloner::Visit(CastExpression& node) { - PushExpression(ShaderNodes::BinaryOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right))); - } + auto clone = std::make_unique(); + clone->targetType = node.targetType; - void ShaderAstCloner::Visit(ShaderNodes::Branch& node) - { - std::vector condStatements; - condStatements.reserve(node.condStatements.size()); - - for (auto& cond : node.condStatements) - { - auto& condStatement = condStatements.emplace_back(); - condStatement.condition = CloneExpression(cond.condition); - condStatement.statement = CloneStatement(cond.statement); - } - - PushStatement(ShaderNodes::Branch::Build(std::move(condStatements), CloneStatement(node.elseStatement))); - } - - void ShaderAstCloner::Visit(ShaderNodes::Cast& node) - { std::size_t expressionCount = 0; - std::array expressions; for (auto& expr : node.expressions) { if (!expr) break; - expressions[expressionCount] = CloneExpression(expr); - expressionCount++; + clone->expressions[expressionCount++] = CloneExpression(expr); } - PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount)); + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node) + void AstCloner::Visit(ConditionalExpression& node) { - PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath))); + auto clone = std::make_unique(); + clone->conditionName = node.conditionName; + clone->falsePath = CloneExpression(node.falsePath); + clone->truePath = CloneExpression(node.truePath); + + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node) + void AstCloner::Visit(ConstantExpression& node) { - PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement))); + auto clone = std::make_unique(); + clone->value = node.value; + + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::Constant& node) + void AstCloner::Visit(IdentifierExpression& node) { - PushExpression(ShaderNodes::Constant::Build(node.value)); + auto clone = std::make_unique(); + clone->identifier = node.identifier; + + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::DeclareVariable& node) + void AstCloner::Visit(IntrinsicExpression& node) { - PushStatement(ShaderNodes::DeclareVariable::Build(CloneVariable(node.variable), CloneExpression(node.expression))); - } - - void ShaderAstCloner::Visit(ShaderNodes::Discard& /*node*/) - { - PushStatement(ShaderNodes::Discard::Build()); - } - - void ShaderAstCloner::Visit(ShaderNodes::ExpressionStatement& node) - { - PushStatement(ShaderNodes::ExpressionStatement::Build(CloneExpression(node.expression))); - } - - void ShaderAstCloner::Visit(ShaderNodes::Identifier& node) - { - PushExpression(ShaderNodes::Identifier::Build(CloneVariable(node.var))); - } - - void ShaderAstCloner::Visit(ShaderNodes::IntrinsicCall& node) - { - std::vector parameters; - parameters.reserve(node.parameters.size()); + auto clone = std::make_unique(); + clone->intrinsic = node.intrinsic; + clone->parameters.reserve(node.parameters.size()); for (auto& parameter : node.parameters) - parameters.push_back(CloneExpression(parameter)); + clone->parameters.push_back(CloneExpression(parameter)); - PushExpression(ShaderNodes::IntrinsicCall::Build(node.intrinsic, std::move(parameters))); + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::NoOp& /*node*/) + void AstCloner::Visit(SwizzleExpression& node) { - PushStatement(ShaderNodes::NoOp::Build()); + auto clone = std::make_unique(); + clone->componentCount = node.componentCount; + clone->components = node.components; + clone->expression = CloneExpression(node.expression); + + PushExpression(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::ReturnStatement& node) + void AstCloner::Visit(BranchStatement& node) { - PushStatement(ShaderNodes::ReturnStatement::Build(CloneExpression(node.returnExpr))); + auto clone = std::make_unique(); + clone->condStatements.reserve(node.condStatements.size()); + + for (auto& cond : node.condStatements) + { + auto& condStatement = clone->condStatements.emplace_back(); + condStatement.condition = CloneExpression(cond.condition); + condStatement.statement = CloneStatement(cond.statement); + } + + clone->elseStatement = CloneStatement(node.elseStatement); + + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node) + void AstCloner::Visit(ConditionalStatement& node) { - PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates))); + auto clone = std::make_unique(); + clone->conditionName = node.conditionName; + clone->statement = CloneStatement(node.statement); + + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::StatementBlock& node) + void AstCloner::Visit(DeclareFunctionStatement& node) { - std::vector statements; - statements.reserve(node.statements.size()); + auto clone = std::make_unique(); + clone->name = node.name; + clone->parameters = node.parameters; + clone->returnType = node.returnType; + clone->statements.reserve(node.statements.size()); for (auto& statement : node.statements) - statements.push_back(CloneStatement(statement)); + clone->statements.push_back(CloneStatement(statement)); - PushStatement(ShaderNodes::StatementBlock::Build(std::move(statements))); + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::SwizzleOp& node) + void AstCloner::Visit(DeclareStructStatement& node) { - PushExpression(ShaderNodes::SwizzleOp::Build(CloneExpression(node.expression), node.components.data(), node.componentCount)); + auto clone = std::make_unique(); + clone->description = node.description; + + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::BuiltinVariable& var) + void AstCloner::Visit(DeclareVariableStatement& node) { - PushVariable(ShaderNodes::BuiltinVariable::Build(var.entry, var.type)); + auto clone = std::make_unique(); + clone->varName = node.varName; + clone->varType = node.varType; + clone->initialExpression = CloneExpression(node.initialExpression); + + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::InputVariable& var) + void AstCloner::Visit(DiscardStatement& /*node*/) { - PushVariable(ShaderNodes::InputVariable::Build(var.name, var.type)); + PushStatement(std::make_unique()); } - void ShaderAstCloner::Visit(ShaderNodes::LocalVariable& var) + void AstCloner::Visit(ExpressionStatement& node) { - PushVariable(ShaderNodes::LocalVariable::Build(var.name, var.type)); + auto clone = std::make_unique(); + clone->expression = CloneExpression(node.expression); + + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::OutputVariable& var) + void AstCloner::Visit(MultiStatement& node) { - PushVariable(ShaderNodes::OutputVariable::Build(var.name, var.type)); + auto clone = std::make_unique(); + clone->statements.reserve(node.statements.size()); + for (auto& statement : node.statements) + clone->statements.push_back(CloneStatement(statement)); + + PushStatement(std::move(clone)); } - void ShaderAstCloner::Visit(ShaderNodes::ParameterVariable& var) + void AstCloner::Visit(NoOpStatement& /*node*/) { - PushVariable(ShaderNodes::ParameterVariable::Build(var.name, var.type)); + PushStatement(std::make_unique()); } - void ShaderAstCloner::Visit(ShaderNodes::UniformVariable& var) + void AstCloner::Visit(ReturnStatement& node) { - PushVariable(ShaderNodes::UniformVariable::Build(var.name, var.type)); + auto clone = std::make_unique(); + clone->returnExpr = CloneExpression(node.returnExpr); + + PushStatement(std::move(clone)); } - void ShaderAstCloner::PushExpression(ShaderNodes::ExpressionPtr expression) + void AstCloner::PushExpression(ExpressionPtr expression) { m_expressionStack.emplace_back(std::move(expression)); } - void ShaderAstCloner::PushStatement(ShaderNodes::StatementPtr statement) + void AstCloner::PushStatement(StatementPtr statement) { m_statementStack.emplace_back(std::move(statement)); } - void ShaderAstCloner::PushVariable(ShaderNodes::VariablePtr variable) - { - m_variableStack.emplace_back(std::move(variable)); - } - - ShaderNodes::ExpressionPtr ShaderAstCloner::PopExpression() + ExpressionPtr AstCloner::PopExpression() { assert(!m_expressionStack.empty()); - ShaderNodes::ExpressionPtr expr = std::move(m_expressionStack.back()); + ExpressionPtr expr = std::move(m_expressionStack.back()); m_expressionStack.pop_back(); return expr; } - ShaderNodes::StatementPtr ShaderAstCloner::PopStatement() + StatementPtr AstCloner::PopStatement() { assert(!m_statementStack.empty()); - ShaderNodes::StatementPtr expr = std::move(m_statementStack.back()); + StatementPtr expr = std::move(m_statementStack.back()); m_statementStack.pop_back(); return expr; } - - ShaderNodes::VariablePtr ShaderAstCloner::PopVariable() - { - assert(!m_variableStack.empty()); - - ShaderNodes::VariablePtr var = std::move(m_variableStack.back()); - m_variableStack.pop_back(); - - return var; - } } diff --git a/src/Nazara/Shader/ShaderAstExpressionType.cpp b/src/Nazara/Shader/ShaderAstExpressionType.cpp new file mode 100644 index 000000000..be4f238e9 --- /dev/null +++ b/src/Nazara/Shader/ShaderAstExpressionType.cpp @@ -0,0 +1,198 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr) + { + m_cache = cache; + ShaderExpressionType type = GetExpressionTypeInternal(expression); + m_cache = nullptr; + + return type; + } + + ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression) + { + m_lastExpressionType.reset(); + + Visit(expression); + + assert(m_lastExpressionType.has_value()); + return std::move(*m_lastExpressionType); + } + + void ExpressionTypeVisitor::Visit(Expression& expression) + { + if (m_cache) + { + auto it = m_cache->nodeExpressionType.find(&expression); + if (it != m_cache->nodeExpressionType.end()) + { + m_lastExpressionType = it->second; + return; + } + } + + expression.Visit(*this); + + if (m_cache) + { + assert(m_lastExpressionType.has_value()); + m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType); + } + } + + void ExpressionTypeVisitor::Visit(AccessMemberExpression& node) + { + throw std::runtime_error("unhandled accessmember expression"); + } + + void ExpressionTypeVisitor::Visit(AssignExpression& node) + { + Visit(*node.left); + } + + void ExpressionTypeVisitor::Visit(BinaryExpression& node) + { + switch (node.op) + { + case BinaryType::Add: + case BinaryType::Subtract: + return Visit(*node.left); + + case BinaryType::Divide: + case BinaryType::Multiply: + { + ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left); + assert(IsBasicType(leftExprType)); + + ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right); + assert(IsBasicType(rightExprType)); + + switch (std::get(leftExprType)) + { + case BasicType::Boolean: + case BasicType::Float2: + case BasicType::Float3: + case BasicType::Float4: + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: + case BasicType::UInt2: + case BasicType::UInt3: + case BasicType::UInt4: + m_lastExpressionType = std::move(leftExprType); + break; + + case BasicType::Float1: + case BasicType::Int1: + case BasicType::Mat4x4: + case BasicType::UInt1: + m_lastExpressionType = std::move(rightExprType); + break; + + case BasicType::Sampler2D: + case BasicType::Void: + break; + } + + break; + } + + case BinaryType::CompEq: + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + case BinaryType::CompNe: + m_lastExpressionType = BasicType::Boolean; + break; + } + } + + void ExpressionTypeVisitor::Visit(CastExpression& node) + { + m_lastExpressionType = node.targetType; + } + + void ExpressionTypeVisitor::Visit(ConditionalExpression& node) + { + ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath); + assert(leftExprType == GetExpressionTypeInternal(*node.falsePath)); + + m_lastExpressionType = std::move(leftExprType); + } + + void ExpressionTypeVisitor::Visit(ConstantExpression& node) + { + m_lastExpressionType = std::visit([&](auto&& arg) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return BasicType::Boolean; + else if constexpr (std::is_same_v) + return BasicType::Float1; + else if constexpr (std::is_same_v) + return BasicType::Int1; + else if constexpr (std::is_same_v) + return BasicType::Int1; + else if constexpr (std::is_same_v) + return BasicType::Float2; + else if constexpr (std::is_same_v) + return BasicType::Float3; + else if constexpr (std::is_same_v) + return BasicType::Float4; + else if constexpr (std::is_same_v) + return BasicType::Int2; + else if constexpr (std::is_same_v) + return BasicType::Int3; + else if constexpr (std::is_same_v) + return BasicType::Int4; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, node.value); + } + + void ExpressionTypeVisitor::Visit(IdentifierExpression& node) + { + auto scopeIt = m_cache->scopeIdByNode.find(&node); + if (scopeIt == m_cache->scopeIdByNode.end()) + throw std::runtime_error("internal error"); + + const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier); + if (!identifier || !std::holds_alternative(identifier->value)) + throw std::runtime_error("internal error"); + + m_lastExpressionType = std::get(identifier->value).type; + } + + void ExpressionTypeVisitor::Visit(IntrinsicExpression& node) + { + switch (node.intrinsic) + { + case IntrinsicType::CrossProduct: + Visit(*node.parameters.front()); + break; + + case IntrinsicType::DotProduct: + m_lastExpressionType = BasicType::Float1; + break; + } + } + + void ExpressionTypeVisitor::Visit(SwizzleExpression& node) + { + const ShaderExpressionType& exprType = GetExpressionTypeInternal(*node.expression); + assert(IsBasicType(exprType)); + + m_lastExpressionType = static_cast(UnderlyingCast(GetComponentType(std::get(exprType))) + node.componentCount - 1); + } +} diff --git a/src/Nazara/Shader/ShaderAstVisitor.cpp b/src/Nazara/Shader/ShaderAstExpressionVisitor.cpp similarity index 52% rename from src/Nazara/Shader/ShaderAstVisitor.cpp rename to src/Nazara/Shader/ShaderAstExpressionVisitor.cpp index 719e3ad99..44beed045 100644 --- a/src/Nazara/Shader/ShaderAstVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstExpressionVisitor.cpp @@ -2,15 +2,10 @@ // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp -#include +#include #include -namespace Nz +namespace Nz::ShaderAst { - ShaderAstVisitor::~ShaderAstVisitor() = default; - - void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node) - { - node->Visit(*this); - } + AstExpressionVisitor::~AstExpressionVisitor() = default; } diff --git a/src/Nazara/Shader/ShaderAstExpressionVisitorExcept.cpp b/src/Nazara/Shader/ShaderAstExpressionVisitorExcept.cpp new file mode 100644 index 000000000..1fddc0e78 --- /dev/null +++ b/src/Nazara/Shader/ShaderAstExpressionVisitorExcept.cpp @@ -0,0 +1,15 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ +#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \ + { \ + throw std::runtime_error("unexpected " #Node " node"); \ + } +#include +} diff --git a/src/Nazara/Shader/ShaderAstOptimizer.cpp b/src/Nazara/Shader/ShaderAstOptimizer.cpp index b2e5acf96..33c8d3f76 100644 --- a/src/Nazara/Shader/ShaderAstOptimizer.cpp +++ b/src/Nazara/Shader/ShaderAstOptimizer.cpp @@ -3,16 +3,22 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include -#include #include +#include #include #include #include -namespace Nz +namespace Nz::ShaderAst { namespace { + template + std::unique_ptr static_unique_pointer_cast(std::unique_ptr&& ptr) + { + return std::unique_ptr(static_cast(ptr.release())); + } + template struct is_complete_helper { @@ -29,14 +35,14 @@ namespace Nz inline constexpr bool is_complete_v = is_complete::value; - template + template struct PropagateConstantType; // CompEq template struct CompEqBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs == rhs); } @@ -46,7 +52,7 @@ namespace Nz struct CompEq; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename CompEq; }; @@ -55,7 +61,7 @@ namespace Nz template struct CompGeBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs >= rhs); } @@ -65,7 +71,7 @@ namespace Nz struct CompGe; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename CompGe; }; @@ -74,7 +80,7 @@ namespace Nz template struct CompGtBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs > rhs); } @@ -84,7 +90,7 @@ namespace Nz struct CompGt; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename CompGt; }; @@ -93,7 +99,7 @@ namespace Nz template struct CompLeBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs <= rhs); } @@ -103,7 +109,7 @@ namespace Nz struct CompLe; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename CompLe; }; @@ -112,7 +118,7 @@ namespace Nz template struct CompLtBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs < rhs); } @@ -122,7 +128,7 @@ namespace Nz struct CompLt; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename CompLe; }; @@ -131,7 +137,7 @@ namespace Nz template struct CompNeBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs != rhs); } @@ -141,7 +147,7 @@ namespace Nz struct CompNe; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename CompNe; }; @@ -150,7 +156,7 @@ namespace Nz template struct AdditionBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs + rhs); } @@ -160,7 +166,7 @@ namespace Nz struct Addition; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename Addition; }; @@ -169,7 +175,7 @@ namespace Nz template struct DivisionBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs / rhs); } @@ -179,7 +185,7 @@ namespace Nz struct Division; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename Division; }; @@ -188,7 +194,7 @@ namespace Nz template struct MultiplicationBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs * rhs); } @@ -198,7 +204,7 @@ namespace Nz struct Multiplication; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename Multiplication; }; @@ -207,7 +213,7 @@ namespace Nz template struct SubtractionBase { - ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) + ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs - rhs); } @@ -217,7 +223,7 @@ namespace Nz struct Subtraction; template - struct PropagateConstantType + struct PropagateConstantType { using Op = typename Subtraction; }; @@ -375,92 +381,89 @@ namespace Nz #undef EnableOptimisation } - ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement) + StatementPtr AstOptimizer::Optimise(StatementPtr& statement) { - m_shaderAst = nullptr; - return CloneStatement(statement); } - ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions) + StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions) { - m_shaderAst = &shader; m_enabledConditions = enabledConditions; return CloneStatement(statement); } - void ShaderAstOptimizer::Visit(ShaderNodes::BinaryOp& node) + void AstOptimizer::Visit(BinaryExpression& node) { auto lhs = CloneExpression(node.left); auto rhs = CloneExpression(node.right); - if (lhs->GetType() == ShaderNodes::NodeType::Constant && rhs->GetType() == ShaderNodes::NodeType::Constant) + if (lhs->GetType() == NodeType::ConstantExpression && rhs->GetType() == NodeType::ConstantExpression) { - auto lhsConstant = std::static_pointer_cast(lhs); - auto rhsConstant = std::static_pointer_cast(rhs); + auto lhsConstant = static_unique_pointer_cast(std::move(lhs)); + auto rhsConstant = static_unique_pointer_cast(std::move(rhs)); switch (node.op) { - case ShaderNodes::BinaryType::Add: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::Add: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::Subtract: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::Subtract: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::Multiply: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::Multiply: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::Divide: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::Divide: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::CompEq: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::CompEq: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::CompGe: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::CompGe: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::CompGt: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::CompGt: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::CompLe: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::CompLe: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::CompLt: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::CompLt: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); - case ShaderNodes::BinaryType::CompNe: - return PropagateConstant(lhsConstant, rhsConstant); + case BinaryType::CompNe: + return PropagateConstant(std::move(lhsConstant), std::move(rhsConstant)); } } - ShaderAstCloner::Visit(node); + AstCloner::Visit(node); } - void ShaderAstOptimizer::Visit(ShaderNodes::Branch& node) + void AstOptimizer::Visit(BranchStatement& node) { - std::vector statements; - ShaderNodes::StatementPtr elseStatement; + std::vector statements; + StatementPtr elseStatement; for (auto& condStatement : node.condStatements) { auto cond = CloneExpression(condStatement.condition); - if (cond->GetType() == ShaderNodes::NodeType::Constant) + if (cond->GetType() == NodeType::ConstantExpression) { - auto constant = std::static_pointer_cast(cond); + auto& constant = static_cast(*cond); - assert(IsBasicType(cond->GetExpressionType())); - assert(std::get(cond->GetExpressionType()) == ShaderNodes::BasicType::Boolean); + assert(IsBasicType(GetExpressionType(constant))); + assert(std::get(GetExpressionType(constant)) == BasicType::Boolean); - bool cValue = std::get(constant->value); + bool cValue = std::get(constant.value); if (!cValue) continue; if (statements.empty()) { // First condition is true, dismiss the branch - Visit(condStatement.statement); + condStatement.statement->Visit(*this); return; } else @@ -482,47 +485,54 @@ namespace Nz { // All conditions have been removed, replace by else statement or no-op if (node.elseStatement) - return Visit(node.elseStatement); + { + node.elseStatement->Visit(*this); + return; + } else - return PushStatement(ShaderNodes::NoOp::Build()); + return PushStatement(ShaderBuilder::NoOp()); } if (!elseStatement) elseStatement = CloneStatement(node.elseStatement); - PushStatement(ShaderNodes::Branch::Build(std::move(statements), std::move(elseStatement))); + PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement))); } - void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalExpression& node) + void AstOptimizer::Visit(ConditionalExpression& node) { - if (!m_shaderAst) + return AstCloner::Visit(node); + + /*if (!m_shaderAst) return ShaderAstCloner::Visit(node); std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName); - assert(conditionIndex != ShaderAst::InvalidCondition); + assert(conditionIndex != InvalidCondition); if (TestBit(m_enabledConditions, conditionIndex)) Visit(node.truePath); else - Visit(node.falsePath); + Visit(node.falsePath);*/ } - void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalStatement& node) + void AstOptimizer::Visit(ConditionalStatement& node) { - if (!m_shaderAst) + return AstCloner::Visit(node); + + /*if (!m_shaderAst) return ShaderAstCloner::Visit(node); std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName); - assert(conditionIndex != ShaderAst::InvalidCondition); + assert(conditionIndex != InvalidCondition); if (TestBit(m_enabledConditions, conditionIndex)) - Visit(node.statement); + Visit(node.statement);*/ } - template - void ShaderAstOptimizer::PropagateConstant(const std::shared_ptr& lhs, const std::shared_ptr& rhs) + template + void AstOptimizer::PropagateConstant(std::unique_ptr&& lhs, std::unique_ptr&& rhs) { - ShaderNodes::ExpressionPtr optimized; + ExpressionPtr optimized; std::visit([&](auto&& arg1) { using T1 = std::decay_t; @@ -543,8 +553,8 @@ namespace Nz }, lhs->value); if (optimized) - PushExpression(optimized); + PushExpression(std::move(optimized)); else - PushExpression(ShaderNodes::BinaryOp::Build(Type, lhs, rhs)); + PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs))); } } diff --git a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp index f4d858d91..c0ef0d920 100644 --- a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp @@ -5,116 +5,121 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AccessMember& node) + void AstRecursiveVisitor::Visit(AccessMemberExpression& node) { - Visit(node.structExpr); + node.structExpr->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AssignOp& node) + void AstRecursiveVisitor::Visit(AssignExpression& node) { - Visit(node.left); - Visit(node.right); + node.left->Visit(*this); + node.right->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::BinaryOp& node) + void AstRecursiveVisitor::Visit(BinaryExpression& node) { - Visit(node.left); - Visit(node.right); + node.left->Visit(*this); + node.right->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Branch& node) - { - for (auto& cond : node.condStatements) - { - Visit(cond.condition); - Visit(cond.statement); - } - - if (node.elseStatement) - Visit(node.elseStatement); - } - - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Cast& node) + void AstRecursiveVisitor::Visit(CastExpression& node) { for (auto& expr : node.expressions) { if (!expr) break; - Visit(expr); + expr->Visit(*this); } } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node) + void AstRecursiveVisitor::Visit(ConditionalExpression& node) { - Visit(node.truePath); - Visit(node.falsePath); + node.truePath->Visit(*this); + node.falsePath->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node) - { - Visit(node.statement); - } - - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/) + void AstRecursiveVisitor::Visit(ConstantExpression& /*node*/) { /* Nothing to do */ } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::DeclareVariable& node) - { - if (node.expression) - Visit(node.expression); - } - - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Discard& /*node*/) + void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/) { /* Nothing to do */ } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ExpressionStatement& node) - { - Visit(node.expression); - } - - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Identifier& /*node*/) - { - /* Nothing to do */ - } - - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::IntrinsicCall& node) + void AstRecursiveVisitor::Visit(IntrinsicExpression& node) { for (auto& param : node.parameters) - Visit(param); + param->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::NoOp& /*node*/) + void AstRecursiveVisitor::Visit(SwizzleExpression& node) + { + node.expression->Visit(*this); + } + + void AstRecursiveVisitor::Visit(BranchStatement& node) + { + for (auto& cond : node.condStatements) + { + cond.condition->Visit(*this); + cond.statement->Visit(*this); + } + + if (node.elseStatement) + node.elseStatement->Visit(*this); + } + + void AstRecursiveVisitor::Visit(ConditionalStatement& node) + { + node.statement->Visit(*this); + } + + void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node) + { + for (auto& statement : node.statements) + statement->Visit(*this); + } + + void AstRecursiveVisitor::Visit(DeclareStructStatement& /*node*/) { /* Nothing to do */ } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ReturnStatement& node) + void AstRecursiveVisitor::Visit(DeclareVariableStatement& node) { - if (node.returnExpr) - Visit(node.returnExpr); + if (node.initialExpression) + node.initialExpression->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node) + void AstRecursiveVisitor::Visit(DiscardStatement& /*node*/) { - Visit(node.sampler); - Visit(node.coordinates); + /* Nothing to do */ } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::StatementBlock& node) + void AstRecursiveVisitor::Visit(ExpressionStatement& node) + { + node.expression->Visit(*this); + } + + void AstRecursiveVisitor::Visit(MultiStatement& node) { for (auto& statement : node.statements) - Visit(statement); + statement->Visit(*this); } - void ShaderAstRecursiveVisitor::Visit(ShaderNodes::SwizzleOp& node) + void AstRecursiveVisitor::Visit(NoOpStatement& /*node*/) { - Visit(node.expression); + /* Nothing to do */ + } + + void AstRecursiveVisitor::Visit(ReturnStatement& node) + { + if (node.returnExpr) + node.returnExpr->Visit(*this); } } diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index 13780339b..4100a0204 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -3,221 +3,74 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include -#include -#include +#include +#include #include -namespace Nz +namespace Nz::ShaderAst { namespace { constexpr UInt32 s_magicNumber = 0x4E534852; constexpr UInt32 s_currentVersion = 1; - class ShaderSerializerVisitor : public ShaderAstVisitor, public ShaderVarVisitor + class ShaderSerializerVisitor : public AstExpressionVisitor, public AstStatementVisitor { public: - ShaderSerializerVisitor(ShaderAstSerializerBase& serializer) : + ShaderSerializerVisitor(AstSerializerBase& serializer) : m_serializer(serializer) { } - void Visit(ShaderNodes::AccessMember& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::AssignOp& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::BinaryOp& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::Branch& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::Cast& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::ConditionalExpression& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::ConditionalStatement& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::Constant& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::DeclareVariable& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::Discard& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::ExpressionStatement& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::Identifier& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::IntrinsicCall& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::NoOp& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::ReturnStatement& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::Sample2D& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::StatementBlock& node) override - { - Serialize(node); - } - - void Visit(ShaderNodes::SwizzleOp& node) override - { - Serialize(node); - } - - - void Visit(ShaderNodes::BuiltinVariable& var) override - { - Serialize(var); - } - - void Visit(ShaderNodes::InputVariable& var) override - { - Serialize(var); - } - - void Visit(ShaderNodes::LocalVariable& var) override - { - Serialize(var); - } - - void Visit(ShaderNodes::OutputVariable& var) override - { - Serialize(var); - } - - void Visit(ShaderNodes::ParameterVariable& var) override - { - Serialize(var); - } - - void Visit(ShaderNodes::UniformVariable& var) override - { - Serialize(var); +#define NAZARA_SHADERAST_NODE(Node) void Visit(Node& node) override \ + { \ + m_serializer.Serialize(node); \ } +#include private: - template - void Serialize(const T& node) - { - // I know const_cast is evil but I don't have a better solution here (it's not used to write) - m_serializer.Serialize(const_cast(node)); - } - - ShaderAstSerializerBase& m_serializer; + AstSerializerBase& m_serializer; }; } - void ShaderAstSerializerBase::Serialize(ShaderNodes::AccessMember& node) + void AstSerializerBase::Serialize(AccessMemberExpression& node) { Node(node.structExpr); - Type(node.exprType); - Container(node.memberIndices); - for (std::size_t& index : node.memberIndices) - SizeT(index); + Container(node.memberIdentifiers); + for (std::string& identifier : node.memberIdentifiers) + Value(identifier); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::AssignOp& node) + void AstSerializerBase::Serialize(AssignExpression& node) { Enum(node.op); Node(node.left); Node(node.right); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::BinaryOp& node) + void AstSerializerBase::Serialize(BinaryExpression& node) { Enum(node.op); Node(node.left); Node(node.right); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::Branch& node) + void AstSerializerBase::Serialize(CastExpression& node) { - Container(node.condStatements); - for (auto& condStatement : node.condStatements) - { - Node(condStatement.condition); - Node(condStatement.statement); - } - - Node(node.elseStatement); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node) - { - Enum(node.entry); - Type(node.type); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::Cast& node) - { - Enum(node.exprType); + Enum(node.targetType); for (auto& expr : node.expressions) Node(expr); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node) + void AstSerializerBase::Serialize(ConditionalExpression& node) { Value(node.conditionName); Node(node.truePath); Node(node.falsePath); } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node) - { - Value(node.conditionName); - Node(node.statement); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node) + + void AstSerializerBase::Serialize(ConstantExpression& node) { UInt32 typeIndex; if (IsWriting()) @@ -251,28 +104,19 @@ namespace Nz } } - void ShaderAstSerializerBase::Serialize(ShaderNodes::DeclareVariable& node) + void AstSerializerBase::Serialize(DeclareVariableStatement& node) { - Variable(node.variable); - Node(node.expression); + Value(node.varName); + Type(node.varType); + Node(node.initialExpression); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::Discard& /*node*/) + void AstSerializerBase::Serialize(IdentifierExpression& node) { - /* Nothing to do */ + Value(node.identifier); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::ExpressionStatement& node) - { - Node(node.expression); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::Identifier& node) - { - Variable(node.var); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::IntrinsicCall& node) + void AstSerializerBase::Serialize(IntrinsicExpression& node) { Enum(node.intrinsic); Container(node.parameters); @@ -280,36 +124,7 @@ namespace Nz Node(param); } - void ShaderAstSerializerBase::Serialize(ShaderNodes::NamedVariable& node) - { - Value(node.name); - Type(node.type); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::NoOp& /*node*/) - { - /* Nothing to do */ - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::ReturnStatement& node) - { - Node(node.returnExpr); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::Sample2D& node) - { - Node(node.sampler); - Node(node.coordinates); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::StatementBlock& node) - { - Container(node.statements); - for (auto& statement : node.statements) - Node(statement); - } - - void ShaderAstSerializerBase::Serialize(ShaderNodes::SwizzleOp& node) + void AstSerializerBase::Serialize(SwizzleExpression& node) { SizeT(node.componentCount); Node(node.expression); @@ -319,100 +134,85 @@ namespace Nz } - void ShaderAstSerializer::Serialize(const ShaderAst& shader) + void AstSerializerBase::Serialize(BranchStatement& node) + { + Container(node.condStatements); + for (auto& condStatement : node.condStatements) + { + Node(condStatement.condition); + Node(condStatement.statement); + } + + Node(node.elseStatement); + } + + void AstSerializerBase::Serialize(ConditionalStatement& node) + { + Value(node.conditionName); + Node(node.statement); + } + + void AstSerializerBase::Serialize(DeclareFunctionStatement& node) + { + Value(node.name); + Type(node.returnType); + + Container(node.parameters); + for (auto& parameter : node.parameters) + { + Value(parameter.name); + Type(parameter.type); + } + + Container(node.statements); + for (auto& statement : node.statements) + Node(statement); + } + + void AstSerializerBase::Serialize(DeclareStructStatement& node) + { + Value(node.description.name); + + Container(node.description.members); + for (auto& member : node.description.members) + { + Value(member.name); + Type(member.type); + } + } + + void AstSerializerBase::Serialize(DiscardStatement& /*node*/) + { + /* Nothing to do */ + } + + void AstSerializerBase::Serialize(ExpressionStatement& node) + { + Node(node.expression); + } + + void AstSerializerBase::Serialize(MultiStatement& node) + { + Container(node.statements); + for (auto& statement : node.statements) + Node(statement); + } + + void AstSerializerBase::Serialize(NoOpStatement& /*node*/) + { + /* Nothing to do */ + } + + void AstSerializerBase::Serialize(ReturnStatement& node) + { + Node(node.returnExpr); + } + + void ShaderAstSerializer::Serialize(StatementPtr& shader) { m_stream << s_magicNumber << s_currentVersion; - m_stream << UInt32(shader.GetStage()); - - auto SerializeType = [&](const ShaderExpressionType& type) - { - std::visit([&](auto&& arg) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - m_stream << UInt8(0); - m_stream << UInt32(arg); - } - else if constexpr (std::is_same_v) - { - m_stream << UInt8(1); - m_stream << arg; - } - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, type); - }; - - auto SerializeInputOutput = [&](auto& inout) - { - m_stream << UInt32(inout.size()); - for (const auto& data : inout) - { - m_stream << data.name; - SerializeType(data.type); - - m_stream << data.locationIndex.has_value(); - if (data.locationIndex) - m_stream << UInt32(data.locationIndex.value()); - } - }; - - // Conditions - m_stream << UInt32(shader.GetConditionCount()); - for (const auto& cond : shader.GetConditions()) - m_stream << cond.name; - - // Structs - m_stream << UInt32(shader.GetStructCount()); - for (const auto& s : shader.GetStructs()) - { - m_stream << s.name; - m_stream << UInt32(s.members.size()); - for (const auto& member : s.members) - { - m_stream << member.name; - SerializeType(member.type); - } - } - - // Inputs / Outputs - SerializeInputOutput(shader.GetInputs()); - SerializeInputOutput(shader.GetOutputs()); - - // Uniforms - m_stream << UInt32(shader.GetUniformCount()); - for (const auto& uniform : shader.GetUniforms()) - { - m_stream << uniform.name; - SerializeType(uniform.type); - - m_stream << uniform.bindingIndex.has_value(); - if (uniform.bindingIndex) - m_stream << UInt32(uniform.bindingIndex.value()); - - m_stream << uniform.memoryLayout.has_value(); - if (uniform.memoryLayout) - m_stream << UInt32(uniform.memoryLayout.value()); - } - - // Functions - m_stream << UInt32(shader.GetFunctionCount()); - for (const auto& func : shader.GetFunctions()) - { - m_stream << func.name; - SerializeType(func.returnType); - - m_stream << UInt32(func.parameters.size()); - for (const auto& param : func.parameters) - { - m_stream << param.name; - SerializeType(param.type); - } - - Node(func.statement); - } + Node(shader); m_stream.FlushBits(); } @@ -422,9 +222,21 @@ namespace Nz return true; } - void ShaderAstSerializer::Node(ShaderNodes::NodePtr& node) + void ShaderAstSerializer::Node(ExpressionPtr& node) { - ShaderNodes::NodeType nodeType = (node) ? node->GetType() : ShaderNodes::NodeType::None; + NodeType nodeType = (node) ? node->GetType() : NodeType::None; + m_stream << static_cast(nodeType); + + if (node) + { + ShaderSerializerVisitor visitor(*this); + node->Visit(visitor); + } + } + + void ShaderAstSerializer::Node(StatementPtr& node) + { + NodeType nodeType = (node) ? node->GetType() : NodeType::None; m_stream << static_cast(nodeType); if (node) @@ -439,7 +251,7 @@ namespace Nz std::visit([&](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { m_stream << UInt8(0); m_stream << UInt32(arg); @@ -454,11 +266,6 @@ namespace Nz }, type); } - void ShaderAstSerializer::Node(const ShaderNodes::NodePtr& node) - { - Node(const_cast(node)); //< Yes const_cast is ugly but it won't be used for writing - } - void ShaderAstSerializer::Value(bool& val) { m_stream << val; @@ -529,19 +336,7 @@ namespace Nz m_stream << val; } - void ShaderAstSerializer::Variable(ShaderNodes::VariablePtr& var) - { - ShaderNodes::VariableType nodeType = (var) ? var->GetType() : ShaderNodes::VariableType::None; - m_stream << static_cast(nodeType); - - if (var) - { - ShaderSerializerVisitor visitor(*this); - var->Visit(visitor); - } - } - - ShaderAst ShaderAstUnserializer::Unserialize() + StatementPtr ShaderAstUnserializer::Unserialize() { UInt32 magicNumber; UInt32 version; @@ -553,122 +348,13 @@ namespace Nz if (version > s_currentVersion) throw std::runtime_error("unsupported version"); - UInt32 shaderStage; - m_stream >> shaderStage; + StatementPtr node; - ShaderAst shader(static_cast(shaderStage)); + Node(node); + if (!node) + throw std::runtime_error("functions can only have statements"); - // Conditions - UInt32 conditionCount; - m_stream >> conditionCount; - for (UInt32 i = 0; i < conditionCount; ++i) - { - std::string conditionName; - Value(conditionName); - - shader.AddCondition(std::move(conditionName)); - } - - // Structs - UInt32 structCount; - m_stream >> structCount; - for (UInt32 i = 0; i < structCount; ++i) - { - std::string structName; - std::vector members; - - Value(structName); - Container(members); - - for (auto& member : members) - { - Value(member.name); - Type(member.type); - } - - shader.AddStruct(std::move(structName), std::move(members)); - } - - // Inputs - UInt32 inputCount; - m_stream >> inputCount; - for (UInt32 i = 0; i < inputCount; ++i) - { - std::string inputName; - ShaderExpressionType inputType; - std::optional location; - - Value(inputName); - Type(inputType); - OptVal(location); - - shader.AddInput(std::move(inputName), std::move(inputType), location); - } - - // Outputs - UInt32 outputCount; - m_stream >> outputCount; - for (UInt32 i = 0; i < outputCount; ++i) - { - std::string outputName; - ShaderExpressionType outputType; - std::optional location; - - Value(outputName); - Type(outputType); - OptVal(location); - - shader.AddOutput(std::move(outputName), std::move(outputType), location); - } - - // Uniforms - UInt32 uniformCount; - m_stream >> uniformCount; - for (UInt32 i = 0; i < uniformCount; ++i) - { - std::string name; - ShaderExpressionType type; - std::optional binding; - std::optional memLayout; - - Value(name); - Type(type); - OptVal(binding); - OptEnum(memLayout); - - shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout)); - } - - // Functions - UInt32 funcCount; - m_stream >> funcCount; - for (UInt32 i = 0; i < funcCount; ++i) - { - std::string name; - ShaderExpressionType retType; - std::vector parameters; - - Value(name); - Type(retType); - - Container(parameters); - for (auto& param : parameters) - { - Value(param.name); - Type(param.type); - } - - ShaderNodes::NodePtr node; - Node(node); - if (!node || !node->IsStatement()) - throw std::runtime_error("functions can only have statements"); - - ShaderNodes::StatementPtr statement = std::static_pointer_cast(node); - - shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), std::move(retType)); - } - - return shader; + return node; } bool ShaderAstUnserializer::IsWriting() const @@ -676,41 +362,50 @@ namespace Nz return false; } - void ShaderAstUnserializer::Node(ShaderNodes::NodePtr& node) + void ShaderAstUnserializer::Node(ExpressionPtr& node) { Int32 nodeTypeInt; m_stream >> nodeTypeInt; - if (nodeTypeInt < static_cast(ShaderNodes::NodeType::None) || nodeTypeInt > static_cast(ShaderNodes::NodeType::Max)) + if (nodeTypeInt < static_cast(NodeType::None) || nodeTypeInt > static_cast(NodeType::Max)) throw std::runtime_error("invalid node type"); - ShaderNodes::NodeType nodeType = static_cast(nodeTypeInt); - -#define HandleType(Type) case ShaderNodes::NodeType:: Type : node = std::make_shared(); break + NodeType nodeType = static_cast(nodeTypeInt); switch (nodeType) { - case ShaderNodes::NodeType::None: break; + case NodeType::None: break; - HandleType(AccessMember); - HandleType(AssignOp); - HandleType(BinaryOp); - HandleType(Branch); - HandleType(Cast); - HandleType(Constant); - HandleType(ConditionalExpression); - HandleType(ConditionalStatement); - HandleType(DeclareVariable); - HandleType(Discard); - HandleType(ExpressionStatement); - HandleType(Identifier); - HandleType(IntrinsicCall); - HandleType(NoOp); - HandleType(ReturnStatement); - HandleType(Sample2D); - HandleType(SwizzleOp); - HandleType(StatementBlock); +#define NAZARA_SHADERAST_EXPRESSION(Node) case NodeType:: Node : node = std::make_unique(); break; +#include + + default: throw std::runtime_error("unexpected node type"); + } + + if (node) + { + ShaderSerializerVisitor visitor(*this); + node->Visit(visitor); + } + } + + void ShaderAstUnserializer::Node(StatementPtr& node) + { + Int32 nodeTypeInt; + m_stream >> nodeTypeInt; + + if (nodeTypeInt < static_cast(NodeType::None) || nodeTypeInt > static_cast(NodeType::Max)) + throw std::runtime_error("invalid node type"); + + NodeType nodeType = static_cast(nodeTypeInt); + switch (nodeType) + { + case NodeType::None: break; + +#define NAZARA_SHADERAST_STATEMENT(Node) case NodeType:: Node : node = std::make_unique(); break; +#include + + default: throw std::runtime_error("unexpected node type"); } -#undef HandleType if (node) { @@ -728,7 +423,7 @@ namespace Nz { case 0: //< Primitive { - ShaderNodes::BasicType exprType; + BasicType exprType; Enum(exprType); type = exprType; @@ -819,36 +514,8 @@ namespace Nz m_stream >> val; } - void ShaderAstUnserializer::Variable(ShaderNodes::VariablePtr& var) - { - Int32 nodeTypeInt; - m_stream >> nodeTypeInt; - ShaderNodes::VariableType nodeType = static_cast(nodeTypeInt); - -#define HandleType(Type) case ShaderNodes::VariableType:: Type : var = std::make_shared(); break - switch (nodeType) - { - case ShaderNodes::VariableType::None: break; - - HandleType(BuiltinVariable); - HandleType(InputVariable); - HandleType(LocalVariable); - HandleType(ParameterVariable); - HandleType(OutputVariable); - HandleType(UniformVariable); - } -#undef HandleType - - if (var) - { - ShaderSerializerVisitor visitor(*this); - var->Visit(visitor); - } - } - - - ByteArray SerializeShader(const ShaderAst& shader) + ByteArray SerializeShader(StatementPtr& shader) { ByteArray byteArray; ByteStream stream(&byteArray, OpenModeFlags(OpenMode_WriteOnly)); @@ -859,7 +526,7 @@ namespace Nz return byteArray; } - ShaderAst UnserializeShader(ByteStream& stream) + StatementPtr UnserializeShader(ByteStream& stream) { ShaderAstUnserializer unserializer(stream); return unserializer.Unserialize(); diff --git a/src/Nazara/Shader/ShaderVarVisitor.cpp b/src/Nazara/Shader/ShaderAstStatementVisitor.cpp similarity index 51% rename from src/Nazara/Shader/ShaderVarVisitor.cpp rename to src/Nazara/Shader/ShaderAstStatementVisitor.cpp index 108d5c69a..6ee90504e 100644 --- a/src/Nazara/Shader/ShaderVarVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstStatementVisitor.cpp @@ -2,15 +2,10 @@ // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp -#include +#include #include -namespace Nz +namespace Nz::ShaderAst { - ShaderVarVisitor::~ShaderVarVisitor() = default; - - void ShaderVarVisitor::Visit(const ShaderNodes::VariablePtr& node) - { - node->Visit(*this); - } + AstStatementVisitor::~AstStatementVisitor() = default; } diff --git a/src/Nazara/Shader/ShaderAstStatementVisitorExcept.cpp b/src/Nazara/Shader/ShaderAstStatementVisitorExcept.cpp new file mode 100644 index 000000000..ef4204ce6 --- /dev/null +++ b/src/Nazara/Shader/ShaderAstStatementVisitorExcept.cpp @@ -0,0 +1,15 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ +#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \ + { \ + throw std::runtime_error("unexpected " #Node " node"); \ + } +#include +} diff --git a/src/Nazara/Shader/ShaderAstUtils.cpp b/src/Nazara/Shader/ShaderAstUtils.cpp index d13571ded..c26d38ae5 100644 --- a/src/Nazara/Shader/ShaderAstUtils.cpp +++ b/src/Nazara/Shader/ShaderAstUtils.cpp @@ -5,69 +5,65 @@ #include #include -namespace Nz +namespace Nz::ShaderAst { - ShaderNodes::ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression) + ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(Expression& expression) { - Visit(expression); + expression.Visit(*this); return m_expressionCategory; } - void ShaderAstValueCategory::Visit(ShaderNodes::AccessMember& node) + void ShaderAstValueCategory::Visit(AccessMemberExpression& node) { - Visit(node.structExpr); + node.structExpr->Visit(*this); } - void ShaderAstValueCategory::Visit(ShaderNodes::AssignOp& node) + void ShaderAstValueCategory::Visit(AssignExpression& /*node*/) { - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + m_expressionCategory = ExpressionCategory::RValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::BinaryOp& node) + void ShaderAstValueCategory::Visit(BinaryExpression& /*node*/) { - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + m_expressionCategory = ExpressionCategory::RValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::Cast& node) + void ShaderAstValueCategory::Visit(CastExpression& /*node*/) { - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + m_expressionCategory = ExpressionCategory::RValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::ConditionalExpression& node) + void ShaderAstValueCategory::Visit(ConditionalExpression& node) { - Visit(node.truePath); - ShaderNodes::ExpressionCategory trueExprCategory = m_expressionCategory; - Visit(node.falsePath); - ShaderNodes::ExpressionCategory falseExprCategory = m_expressionCategory; + node.truePath->Visit(*this); + ExpressionCategory trueExprCategory = m_expressionCategory; - if (trueExprCategory == ShaderNodes::ExpressionCategory::RValue || falseExprCategory == ShaderNodes::ExpressionCategory::RValue) - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + node.falsePath->Visit(*this); + ExpressionCategory falseExprCategory = m_expressionCategory; + + if (trueExprCategory == ExpressionCategory::RValue || falseExprCategory == ExpressionCategory::RValue) + m_expressionCategory = ExpressionCategory::RValue; else - m_expressionCategory = ShaderNodes::ExpressionCategory::LValue; + m_expressionCategory = ExpressionCategory::LValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::Constant& node) + void ShaderAstValueCategory::Visit(ConstantExpression& /*node*/) { - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + m_expressionCategory = ExpressionCategory::RValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::Identifier& node) + void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/) { - m_expressionCategory = ShaderNodes::ExpressionCategory::LValue; + m_expressionCategory = ExpressionCategory::LValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::IntrinsicCall& node) + void ShaderAstValueCategory::Visit(IntrinsicExpression& /*node*/) { - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + m_expressionCategory = ExpressionCategory::RValue; } - void ShaderAstValueCategory::Visit(ShaderNodes::Sample2D& node) + void ShaderAstValueCategory::Visit(SwizzleExpression& node) { - m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; - } - - void ShaderAstValueCategory::Visit(ShaderNodes::SwizzleOp& node) - { - Visit(node.expression); + node.expression->Visit(*this); } } diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index f081d7481..0e5d8f52f 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -4,48 +4,40 @@ #include #include -#include #include -#include +#include #include #include -namespace Nz +namespace Nz::ShaderAst { struct AstError { std::string errMsg; }; - struct ShaderAstValidator::Context + struct AstValidator::Context { - struct Local - { - std::string name; - ShaderExpressionType type; - }; - - const ShaderAst::Function* currentFunction; - std::vector declaredLocals; - std::vector blockLocalIndex; + //const ShaderAst::Function* currentFunction; + std::optional activeScopeId; + AstCache* cache; }; - bool ShaderAstValidator::Validate(std::string* error) + bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache) { try { - for (std::size_t i = 0; i < m_shader.GetFunctionCount(); ++i) - { - const auto& func = m_shader.GetFunction(i); + AstCache dummy; - Context currentContext; - currentContext.currentFunction = &func; + Context currentContext; + currentContext.cache = (cache) ? cache : &dummy; - m_context = ¤tContext; - CallOnExit resetContext([&] { m_context = nullptr; }); + m_context = ¤tContext; + CallOnExit resetContext([&] { m_context = nullptr; }); - func.statement->Visit(*this); - } + EnterScope(); + node->Visit(*this); + ExitScope(); return true; } @@ -58,148 +50,183 @@ namespace Nz } } - const ShaderNodes::ExpressionPtr& ShaderAstValidator::MandatoryExpr(const ShaderNodes::ExpressionPtr& node) - { - MandatoryNode(node); - - return node; - } - - const ShaderNodes::NodePtr& ShaderAstValidator::MandatoryNode(const ShaderNodes::NodePtr& node) + Expression& AstValidator::MandatoryExpr(ExpressionPtr& node) { if (!node) - throw AstError{ "Invalid node" }; + throw AstError{ "Invalid expression" }; - return node; + return *node; } - void ShaderAstValidator::TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) + Statement& AstValidator::MandatoryStatement(StatementPtr& node) { - return TypeMustMatch(left->GetExpressionType(), right->GetExpressionType()); + if (!node) + throw AstError{ "Invalid statement" }; + + return *node; } - void ShaderAstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right) + void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right) + { + return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache)); + } + + void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right) { if (left != right) throw AstError{ "Left expression type must match right expression type" }; } - const ShaderAst::StructMember& ShaderAstValidator::CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers) + ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { - const auto& structs = m_shader.GetStructs(); - auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; }); - if (it == structs.end()) - throw AstError{ "invalid structure" }; + const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName); + if (!identifier) + throw AstError{ "unknown identifier " + structName }; - const ShaderAst::Struct& s = *it; - if (*memberIndex >= s.members.size()) - throw AstError{ "member index out of bounds" }; + if (std::holds_alternative(identifier->value)) + throw AstError{ "identifier is not a struct" }; - const auto& member = s.members[*memberIndex]; + const StructDescription& s = std::get(identifier->value); + + auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); + if (memberIt == s.members.end()) + throw AstError{ "unknown field " + memberIdentifier[0]}; + + const auto& member = *memberIt; if (remainingMembers > 1) - { - if (!IsStructType(member.type)) - throw AstError{ "member type does not match node type" }; - - return CheckField(std::get(member.type), memberIndex + 1, remainingMembers - 1); - } + return CheckField(std::get(member.type), memberIdentifier + 1, remainingMembers - 1); else - return member; + return member.type; } - void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node) + AstCache::Scope& AstValidator::EnterScope() { - const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType(); + std::size_t newScopeId = m_context->cache->scopes.size(); + + std::optional previousScope = m_context->activeScopeId; + + auto& newScope = m_context->cache->scopes.emplace_back(); + newScope.parentScopeIndex = previousScope; + + m_context->activeScopeId = newScopeId; + return m_context->cache->scopes[newScopeId]; + } + + void AstValidator::ExitScope() + { + assert(m_context->activeScopeId); + auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId]; + m_context->activeScopeId = previousScope.parentScopeIndex; + } + + void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType) + { + m_context->cache->nodeExpressionType[&node] = std::move(expressionType); + } + + void AstValidator::RegisterScope(Node& node) + { + if (m_context->activeScopeId) + m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId; + } + + void AstValidator::Visit(AccessMemberExpression& node) + { + RegisterScope(node); + + const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); if (!IsStructType(exprType)) throw AstError{ "expression is not a structure" }; const std::string& structName = std::get(exprType); - const auto& member = CheckField(structName, node.memberIndices.data(), node.memberIndices.size()); - if (member.type != node.exprType) - throw AstError{ "member type does not match node type" }; + RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size())); } - void ShaderAstValidator::Visit(ShaderNodes::AssignOp& node) + void AstValidator::Visit(AssignExpression& node) { - MandatoryNode(node.left); - MandatoryNode(node.right); + RegisterScope(node); + + MandatoryExpr(node.left); + MandatoryExpr(node.right); TypeMustMatch(node.left, node.right); - if (GetExpressionCategory(node.left) != ShaderNodes::ExpressionCategory::LValue) + if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) throw AstError { "Assignation is only possible with a l-value" }; - ShaderAstRecursiveVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::BinaryOp& node) + void AstValidator::Visit(BinaryExpression& node) { - MandatoryNode(node.left); - MandatoryNode(node.right); + RegisterScope(node); - const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType(); + // Register expression type + AstRecursiveVisitor::Visit(node); + + const ShaderExpressionType& leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); if (!IsBasicType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; - const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType(); + const ShaderExpressionType& rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); if (!IsBasicType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; - ShaderNodes::BasicType leftType = std::get(leftExprType); - ShaderNodes::BasicType rightType = std::get(rightExprType); + BasicType leftType = std::get(leftExprType); + BasicType rightType = std::get(rightExprType); switch (node.op) { - case ShaderNodes::BinaryType::CompGe: - case ShaderNodes::BinaryType::CompGt: - case ShaderNodes::BinaryType::CompLe: - case ShaderNodes::BinaryType::CompLt: - if (leftType == ShaderNodes::BasicType::Boolean) + case BinaryType::CompGe: + case BinaryType::CompGt: + case BinaryType::CompLe: + case BinaryType::CompLt: + if (leftType == BasicType::Boolean) throw AstError{ "this operation is not supported for booleans" }; [[fallthrough]]; - case ShaderNodes::BinaryType::Add: - case ShaderNodes::BinaryType::CompEq: - case ShaderNodes::BinaryType::CompNe: - case ShaderNodes::BinaryType::Subtract: + case BinaryType::Add: + case BinaryType::CompEq: + case BinaryType::CompNe: + case BinaryType::Subtract: TypeMustMatch(node.left, node.right); break; - case ShaderNodes::BinaryType::Multiply: - case ShaderNodes::BinaryType::Divide: + case BinaryType::Multiply: + case BinaryType::Divide: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Int1: + case BasicType::Float1: + case BasicType::Int1: { - if (ShaderNodes::Node::GetComponentType(rightType) != leftType) + if (GetComponentType(rightType) != leftType) throw AstError{ "Left expression type is not compatible with right expression type" }; break; } - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case BasicType::Float2: + case BasicType::Float3: + case BasicType::Float4: + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: { - if (leftType != rightType && rightType != ShaderNodes::Node::GetComponentType(leftType)) + if (leftType != rightType && rightType != GetComponentType(leftType)) throw AstError{ "Left expression type is not compatible with right expression type" }; break; } - case ShaderNodes::BasicType::Mat4x4: + case BasicType::Mat4x4: { switch (rightType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case BasicType::Float1: + case BasicType::Float4: + case BasicType::Mat4x4: break; default: @@ -211,120 +238,86 @@ namespace Nz default: TypeMustMatch(node.left, node.right); + break; } } } - - ShaderAstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::Branch& node) + void AstValidator::Visit(CastExpression& node) { - for (const auto& condStatement : node.condStatements) - { - const ShaderExpressionType& condType = MandatoryExpr(condStatement.condition)->GetExpressionType(); - if (!IsBasicType(condType) || std::get(condType) != ShaderNodes::BasicType::Boolean) - throw AstError{ "if expression must resolve to boolean type" }; + RegisterScope(node); - MandatoryNode(condStatement.statement); - } - - ShaderAstRecursiveVisitor::Visit(node); - } - - void ShaderAstValidator::Visit(ShaderNodes::Cast& node) - { unsigned int componentCount = 0; - unsigned int requiredComponents = node.GetComponentCount(node.exprType); - for (const auto& exprPtr : node.expressions) + unsigned int requiredComponents = GetComponentCount(node.targetType); + for (auto& exprPtr : node.expressions) { if (!exprPtr) break; - const ShaderExpressionType& exprType = exprPtr->GetExpressionType(); + ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache); if (!IsBasicType(exprType)) throw AstError{ "incompatible type" }; - componentCount += node.GetComponentCount(std::get(exprType)); + componentCount += GetComponentCount(std::get(exprType)); } if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; - ShaderAstRecursiveVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node) + void AstValidator::Visit(ConditionalExpression& node) { - MandatoryNode(node.truePath); - MandatoryNode(node.falsePath); + MandatoryExpr(node.truePath); + MandatoryExpr(node.falsePath); - if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) - throw AstError{ "condition not found" }; + RegisterScope(node); + + AstRecursiveVisitor::Visit(node); + //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) + // throw AstError{ "condition not found" }; } - void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node) + void AstValidator::Visit(ConstantExpression& node) { - MandatoryNode(node.statement); - - if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) - throw AstError{ "condition not found" }; + RegisterScope(node); } - void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/) - { - } - - void ShaderAstValidator::Visit(ShaderNodes::DeclareVariable& node) + void AstValidator::Visit(IdentifierExpression& node) { assert(m_context); - if (node.variable->GetType() != ShaderNodes::VariableType::LocalVariable) - throw AstError{ "Only local variables can be declared in a statement" }; + if (!m_context->activeScopeId) + throw AstError{ "no scope" }; - const auto& localVar = static_cast(*node.variable); + RegisterScope(node); - auto& local = m_context->declaredLocals.emplace_back(); - local.name = localVar.name; - local.type = localVar.type; - - ShaderAstRecursiveVisitor::Visit(node); + const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier); + if (!identifier) + throw AstError{ "Unknown variable " + node.identifier }; } - - void ShaderAstValidator::Visit(ShaderNodes::ExpressionStatement& node) + + void AstValidator::Visit(IntrinsicExpression& node) { - MandatoryNode(node.expression); + RegisterScope(node); - ShaderAstRecursiveVisitor::Visit(node); - } - - void ShaderAstValidator::Visit(ShaderNodes::Identifier& node) - { - assert(m_context); - - if (!node.var) - throw AstError{ "Invalid variable" }; - - Visit(node.var); - } - - void ShaderAstValidator::Visit(ShaderNodes::IntrinsicCall& node) - { switch (node.intrinsic) { - case ShaderNodes::IntrinsicType::CrossProduct: - case ShaderNodes::IntrinsicType::DotProduct: + case IntrinsicType::CrossProduct: + case IntrinsicType::DotProduct: { if (node.parameters.size() != 2) throw AstError { "Expected 2 parameters" }; for (auto& param : node.parameters) - MandatoryNode(param); + MandatoryExpr(param); - ShaderExpressionType type = node.parameters.front()->GetExpressionType(); + ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache); for (std::size_t i = 1; i < node.parameters.size(); ++i) { - if (type != node.parameters[i]->GetExpressionType()) + if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache) throw AstError{ "All type must match" }; } @@ -334,180 +327,176 @@ namespace Nz switch (node.intrinsic) { - case ShaderNodes::IntrinsicType::CrossProduct: + case IntrinsicType::CrossProduct: { - if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 }) + if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache) throw AstError{ "CrossProduct only works with Float3 expressions" }; break; } - case ShaderNodes::IntrinsicType::DotProduct: + case IntrinsicType::DotProduct: break; } - ShaderAstRecursiveVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::ReturnStatement& node) + void AstValidator::Visit(SwizzleExpression& node) { - if (m_context->currentFunction->returnType != ShaderExpressionType(ShaderNodes::BasicType::Void)) - { - if (MandatoryExpr(node.returnExpr)->GetExpressionType() != m_context->currentFunction->returnType) - throw AstError{ "Return type doesn't match function return type" }; - } - else - { - if (node.returnExpr) - throw AstError{ "Unexpected expression for return (function doesn't return)" }; - } + RegisterScope(node); - ShaderAstRecursiveVisitor::Visit(node); - } - - void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node) - { - if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D }) - throw AstError{ "Sampler must be a Sampler2D" }; - - if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 }) - throw AstError{ "Coordinates must be a Float2" }; - - ShaderAstRecursiveVisitor::Visit(node); - } - - void ShaderAstValidator::Visit(ShaderNodes::StatementBlock& node) - { - assert(m_context); - - m_context->blockLocalIndex.push_back(m_context->declaredLocals.size()); - - for (const auto& statement : node.statements) - MandatoryNode(statement); - - assert(m_context->declaredLocals.size() >= m_context->blockLocalIndex.back()); - m_context->declaredLocals.resize(m_context->blockLocalIndex.back()); - m_context->blockLocalIndex.pop_back(); - - ShaderAstRecursiveVisitor::Visit(node); - } - - void ShaderAstValidator::Visit(ShaderNodes::SwizzleOp& node) - { if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; - const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType(); + const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); if (!IsBasicType(exprType)) throw AstError{ "Cannot swizzle this type" }; - switch (std::get(exprType)) + switch (std::get(exprType)) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case BasicType::Float1: + case BasicType::Float2: + case BasicType::Float3: + case BasicType::Float4: + case BasicType::Int1: + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: break; default: throw AstError{ "Cannot swizzle this type" }; } - ShaderAstRecursiveVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::BuiltinVariable& var) + void AstValidator::Visit(BranchStatement& node) { - switch (var.entry) + RegisterScope(node); + + for (auto& condStatement : node.condStatements) { - case ShaderNodes::BuiltinEntry::VertexPosition: - if (!IsBasicType(var.type) || - std::get(var.type) != ShaderNodes::BasicType::Float4) - throw AstError{ "Builtin is not of the expected type" }; + const ShaderExpressionType& condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); + if (!IsBasicType(condType) || std::get(condType) != BasicType::Boolean) + throw AstError{ "if expression must resolve to boolean type" }; - break; - - default: - break; - } - } - - void ShaderAstValidator::Visit(ShaderNodes::InputVariable& var) - { - for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i) - { - const auto& input = m_shader.GetInput(i); - if (input.name == var.name) - { - TypeMustMatch(input.type, var.type); - return; - } + MandatoryStatement(condStatement.statement); } - throw AstError{ "Input not found" }; + AstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::LocalVariable& var) + void AstValidator::Visit(ConditionalStatement& node) { - const auto& vars = m_context->declaredLocals; + MandatoryStatement(node.statement); - auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; }); - if (it == vars.end()) - throw AstError{ "Local variable not found in this block" }; + RegisterScope(node); - TypeMustMatch(it->type, var.type); + AstRecursiveVisitor::Visit(node); + //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) + // throw AstError{ "condition not found" }; } - void ShaderAstValidator::Visit(ShaderNodes::OutputVariable& var) + void AstValidator::Visit(DeclareFunctionStatement& node) { - for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i) + auto& scope = EnterScope(); + + RegisterScope(node); + + for (auto& parameter : node.parameters) { - const auto& input = m_shader.GetOutput(i); - if (input.name == var.name) - { - TypeMustMatch(input.type, var.type); - return; - } + auto& identifier = scope.identifiers.emplace_back(); + identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } }; } - throw AstError{ "Output not found" }; + for (auto& statement : node.statements) + MandatoryStatement(statement).Visit(*this); + + ExitScope(); } - void ShaderAstValidator::Visit(ShaderNodes::ParameterVariable& var) + void AstValidator::Visit(DeclareStructStatement& node) { - assert(m_context->currentFunction); + assert(m_context); - const auto& parameters = m_context->currentFunction->parameters; + if (!m_context->activeScopeId) + throw AstError{ "cannot declare variable without scope" }; - auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; }); - if (it == parameters.end()) - throw AstError{ "Parameter not found in function" }; + RegisterScope(node); - TypeMustMatch(it->type, var.type); + auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; + + auto& identifier = scope.identifiers.emplace_back(); + identifier = AstCache::Identifier{ node.description.name, node.description }; + + AstRecursiveVisitor::Visit(node); } - void ShaderAstValidator::Visit(ShaderNodes::UniformVariable& var) + void AstValidator::Visit(DeclareVariableStatement& node) { - for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i) + assert(m_context); + + if (!m_context->activeScopeId) + throw AstError{ "cannot declare variable without scope" }; + + RegisterScope(node); + + auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; + + auto& identifier = scope.identifiers.emplace_back(); + identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } }; + + AstRecursiveVisitor::Visit(node); + } + + void AstValidator::Visit(ExpressionStatement& node) + { + RegisterScope(node); + + MandatoryExpr(node.expression); + + AstRecursiveVisitor::Visit(node); + } + + void AstValidator::Visit(MultiStatement& node) + { + assert(m_context); + + EnterScope(); + + RegisterScope(node); + + for (auto& statement : node.statements) + MandatoryStatement(statement); + + ExitScope(); + + AstRecursiveVisitor::Visit(node); + } + + void AstValidator::Visit(ReturnStatement& node) + { + RegisterScope(node); + + /*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void)) { - const auto& uniform = m_shader.GetUniform(i); - if (uniform.name == var.name) - { - TypeMustMatch(uniform.type, var.type); - return; - } + if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType) + throw AstError{ "Return type doesn't match function return type" }; } + else + { + if (node.returnExpr) + throw AstError{ "Unexpected expression for return (function doesn't return)" }; + }*/ - throw AstError{ "Uniform not found" }; + AstRecursiveVisitor::Visit(node); } - bool ValidateShader(const ShaderAst& shader, std::string* error) + bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache) { - ShaderAstValidator validator(shader); - return validator.Validate(error); + AstValidator validator; + return validator.Validate(node, error, cache); } } diff --git a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp deleted file mode 100644 index b8f8c7d16..000000000 --- a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include - -namespace Nz -{ - void ShaderAstVisitorExcept::Visit(ShaderNodes::AccessMember& /*node*/) - { - throw std::runtime_error("unhandled AccessMember node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::AssignOp& /*node*/) - { - throw std::runtime_error("unhandled AssignOp node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::BinaryOp& /*node*/) - { - throw std::runtime_error("unhandled AccessMember node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::Branch& /*node*/) - { - throw std::runtime_error("unhandled Branch node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::Cast& /*node*/) - { - throw std::runtime_error("unhandled Cast node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/) - { - throw std::runtime_error("unhandled ConditionalExpression node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/) - { - throw std::runtime_error("unhandled ConditionalStatement node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/) - { - throw std::runtime_error("unhandled Constant node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::DeclareVariable& /*node*/) - { - throw std::runtime_error("unhandled DeclareVariable node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::Discard& /*node*/) - { - throw std::runtime_error("unhandled Discard node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::ExpressionStatement& /*node*/) - { - throw std::runtime_error("unhandled ExpressionStatement node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::Identifier& /*node*/) - { - throw std::runtime_error("unhandled Identifier node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::IntrinsicCall& /*node*/) - { - throw std::runtime_error("unhandled IntrinsicCall node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::NoOp& node) - { - throw std::runtime_error("unhandled NoOp node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::ReturnStatement& node) - { - throw std::runtime_error("unhandled ReturnStatement node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& /*node*/) - { - throw std::runtime_error("unhandled Sample2D node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::StatementBlock& /*node*/) - { - throw std::runtime_error("unhandled StatementBlock node"); - } - - void ShaderAstVisitorExcept::Visit(ShaderNodes::SwizzleOp& /*node*/) - { - throw std::runtime_error("unhandled SwizzleOp node"); - } -} diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index 8782a54fb..13b21b0c4 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -42,6 +42,7 @@ namespace Nz::ShaderLang std::unordered_map reservedKeywords = { { "false", TokenType::BoolFalse }, { "fn", TokenType::FunctionDeclaration }, + { "let", TokenType::Let }, { "return", TokenType::Return }, { "true", TokenType::BoolTrue } }; @@ -143,7 +144,7 @@ namespace Nz::ShaderLang while (next != -1); } else - tokenType == TokenType::Divide; + tokenType = TokenType::Divide; break; } @@ -191,9 +192,11 @@ namespace Nz::ShaderLang std::string valueStr(str.substr(start, currentPos - start + 1)); + const char* ptr = valueStr.c_str(); + char* end; - double value = std::strtod(valueStr.c_str(), &end); - if (end != &str[currentPos + 1]) + double value = std::strtod(ptr, &end); + if (end != &ptr[valueStr.size()]) throw BadNumber{}; token.data = value; @@ -218,6 +221,7 @@ namespace Nz::ShaderLang break; } + case '=': tokenType = TokenType::Assign; break; case '+': tokenType = TokenType::Plus; break; case '*': tokenType = TokenType::Multiply; break; case ':': tokenType = TokenType::Colon; break; diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index cea55ee2e..050d88f6d 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include #include @@ -10,36 +11,38 @@ namespace Nz::ShaderLang { namespace { - std::unordered_map identifierToBasicType = { - { "bool", ShaderNodes::BasicType::Boolean }, + std::unordered_map identifierToBasicType = { + { "bool", ShaderAst::BasicType::Boolean }, - { "i32", ShaderNodes::BasicType::Int1 }, - { "vec2i32", ShaderNodes::BasicType::Int2 }, - { "vec3i32", ShaderNodes::BasicType::Int3 }, - { "vec4i32", ShaderNodes::BasicType::Int4 }, + { "i32", ShaderAst::BasicType::Int1 }, + { "vec2i32", ShaderAst::BasicType::Int2 }, + { "vec3i32", ShaderAst::BasicType::Int3 }, + { "vec4i32", ShaderAst::BasicType::Int4 }, - { "f32", ShaderNodes::BasicType::Float1 }, - { "vec2f32", ShaderNodes::BasicType::Float2 }, - { "vec3f32", ShaderNodes::BasicType::Float3 }, - { "vec4f32", ShaderNodes::BasicType::Float4 }, + { "f32", ShaderAst::BasicType::Float1 }, + { "vec2f32", ShaderAst::BasicType::Float2 }, + { "vec3f32", ShaderAst::BasicType::Float3 }, + { "vec4f32", ShaderAst::BasicType::Float4 }, - { "mat4x4f32", ShaderNodes::BasicType::Mat4x4 }, - { "sampler2D", ShaderNodes::BasicType::Sampler2D }, - { "void", ShaderNodes::BasicType::Void }, + { "mat4x4f32", ShaderAst::BasicType::Mat4x4 }, + { "sampler2D", ShaderAst::BasicType::Sampler2D }, + { "void", ShaderAst::BasicType::Void }, - { "u32", ShaderNodes::BasicType::UInt1 }, - { "vec2u32", ShaderNodes::BasicType::UInt3 }, - { "vec3u32", ShaderNodes::BasicType::UInt3 }, - { "vec4u32", ShaderNodes::BasicType::UInt4 }, + { "u32", ShaderAst::BasicType::UInt1 }, + { "vec2u32", ShaderAst::BasicType::UInt3 }, + { "vec3u32", ShaderAst::BasicType::UInt3 }, + { "vec4u32", ShaderAst::BasicType::UInt4 }, }; } - ShaderAst Parser::Parse(const std::vector& tokens) + ShaderAst::StatementPtr Parser::Parse(const std::vector& tokens) { Context context; context.tokenCount = tokens.size(); context.tokens = tokens.data(); + context.root = std::make_unique(); + m_context = &context; m_context->tokenIndex = -1; @@ -51,7 +54,7 @@ namespace Nz::ShaderLang switch (nextToken.type) { case TokenType::FunctionDeclaration: - ParseFunctionDeclaration(); + context.root->statements.push_back(ParseFunctionDeclaration()); break; case TokenType::EndOfStream: @@ -63,7 +66,7 @@ namespace Nz::ShaderLang } } - return std::move(context.result); + return std::move(context.root); } const Token& Parser::Advance() @@ -92,12 +95,12 @@ namespace Nz::ShaderLang return m_context->tokens[m_context->tokenIndex + 1]; } - ShaderNodes::StatementPtr Parser::ParseFunctionBody() + std::vector Parser::ParseFunctionBody() { return ParseStatementList(); } - void Parser::ParseFunctionDeclaration() + ShaderAst::StatementPtr Parser::ParseFunctionDeclaration() { ExpectNext(TokenType::FunctionDeclaration); @@ -105,7 +108,7 @@ namespace Nz::ShaderLang ExpectNext(TokenType::OpenParenthesis); - std::vector parameters; + std::vector parameters; bool firstParameter = true; for (;;) @@ -126,7 +129,7 @@ namespace Nz::ShaderLang ExpectNext(TokenType::ClosingParenthesis); - ShaderExpressionType returnType = ShaderNodes::BasicType::Void; + ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void; if (PeekNext().type == TokenType::FunctionReturn) { Advance(); //< Consume -> @@ -136,42 +139,46 @@ namespace Nz::ShaderLang ExpectNext(TokenType::OpenCurlyBracket); - ShaderNodes::StatementPtr functionBody = ParseFunctionBody(); + std::vector functionBody = ParseFunctionBody(); ExpectNext(TokenType::ClosingCurlyBracket); - m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType); + return ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); } - ShaderAst::FunctionParameter Parser::ParseFunctionParameter() + ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter() { std::string parameterName = ParseIdentifierAsName(); ExpectNext(TokenType::Colon); - ShaderExpressionType parameterType = ParseIdentifierAsType(); + ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType(); return { parameterName, parameterType }; } - ShaderNodes::StatementPtr Parser::ParseReturnStatement() + ShaderAst::StatementPtr Parser::ParseReturnStatement() { ExpectNext(TokenType::Return); - ShaderNodes::ExpressionPtr expr; + ShaderAst::ExpressionPtr expr; if (PeekNext().type != TokenType::Semicolon) expr = ParseExpression(); - return ShaderNodes::ReturnStatement::Build(std::move(expr)); + return ShaderBuilder::Return(std::move(expr)); } - ShaderNodes::StatementPtr Parser::ParseStatement() + ShaderAst::StatementPtr Parser::ParseStatement() { const Token& token = PeekNext(); - ShaderNodes::StatementPtr statement; + ShaderAst::StatementPtr statement; switch (token.type) { + case TokenType::Let: + statement = ParseVariableDeclaration(); + break; + case TokenType::Return: statement = ParseReturnStatement(); break; @@ -185,18 +192,38 @@ namespace Nz::ShaderLang return statement; } - ShaderNodes::StatementPtr Parser::ParseStatementList() + std::vector Parser::ParseStatementList() { - std::vector statements; + std::vector statements; while (PeekNext().type != TokenType::ClosingCurlyBracket) { statements.push_back(ParseStatement()); } - return ShaderNodes::StatementBlock::Build(std::move(statements)); + return statements; } - ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs) + ShaderAst::StatementPtr Parser::ParseVariableDeclaration() + { + ExpectNext(TokenType::Let); + + std::string variableName = ParseIdentifierAsName(); + + ExpectNext(TokenType::Colon); + + ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType(); + + ShaderAst::ExpressionPtr expression; + if (PeekNext().type == TokenType::Assign) + { + Advance(); + expression = ParseExpression(); + } + + return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression)); + } + + ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs) { for (;;) { @@ -207,7 +234,7 @@ namespace Nz::ShaderLang return lhs; Advance(); - ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression(); + ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression(); const Token& nextOp = PeekNext(); @@ -215,57 +242,58 @@ namespace Nz::ShaderLang if (tokenPrecedence < nextTokenPrecedence) rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs)); - ShaderNodes::BinaryType binaryType; + ShaderAst::BinaryType binaryType; { switch (currentOp.type) { - case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break; - case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break; - case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break; - case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break; + case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break; + case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; + case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break; + case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break; default: throw UnexpectedToken{}; } } - lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs)); + lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs)); } } - ShaderNodes::ExpressionPtr Parser::ParseExpression() + ShaderAst::ExpressionPtr Parser::ParseExpression() { return ParseBinOpRhs(0, ParsePrimaryExpression()); } - ShaderNodes::ExpressionPtr Parser::ParseIdentifier() + ShaderAst::ExpressionPtr Parser::ParseIdentifier() { const Token& identifier = ExpectNext(TokenType::Identifier); - return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get(identifier.data), ShaderNodes::BasicType::Float3)); + return ShaderBuilder::Identifier(std::get(identifier.data)); } - ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression() + ShaderAst::ExpressionPtr Parser::ParseIntegerExpression() { const Token& integer = ExpectNext(TokenType::IntegerValue); - return ShaderNodes::Constant::Build(static_cast(std::get(integer.data))); + return ShaderBuilder::Constant(static_cast(std::get(integer.data))); } - ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression() + ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression() { ExpectNext(TokenType::OpenParenthesis); - ShaderNodes::ExpressionPtr expression = ParseExpression(); + ShaderAst::ExpressionPtr expression = ParseExpression(); ExpectNext(TokenType::ClosingParenthesis); return expression; } - ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression() + ShaderAst::ExpressionPtr Parser::ParsePrimaryExpression() { const Token& token = PeekNext(); switch (token.type) { - case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false); - case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true); + case TokenType::BoolFalse: return ShaderBuilder::Constant(false); + case TokenType::BoolTrue: return ShaderBuilder::Constant(true); + case TokenType::FloatingPointValue: return ShaderBuilder::Constant(float(std::get(Advance().data))); //< FIXME case TokenType::Identifier: return ParseIdentifier(); case TokenType::IntegerValue: return ParseIntegerExpression(); case TokenType::OpenParenthesis: return ParseParenthesisExpression(); @@ -286,7 +314,7 @@ namespace Nz::ShaderLang return identifier; } - ShaderExpressionType Parser::ParseIdentifierAsType() + ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType() { const Token& identifier = ExpectNext(TokenType::Identifier); diff --git a/src/Nazara/Shader/ShaderNodes.cpp b/src/Nazara/Shader/ShaderNodes.cpp index 97a21506b..d4510fcfc 100644 --- a/src/Nazara/Shader/ShaderNodes.cpp +++ b/src/Nazara/Shader/ShaderNodes.cpp @@ -4,265 +4,29 @@ #include #include -#include -#include -#include +#include +#include #include -namespace Nz::ShaderNodes +namespace Nz::ShaderAst { Node::~Node() = default; - void ExpressionStatement::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); +#define NAZARA_SHADERAST_NODE(Node) NodeType Node::GetType() const \ + { \ + return NodeType:: Node; \ + } +#include + +#define NAZARA_SHADERAST_EXPRESSION(Node) void Node::Visit(AstExpressionVisitor& visitor) \ + {\ + visitor.Visit(*this); \ } - - void ConditionalStatement::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); +#define NAZARA_SHADERAST_STATEMENT(Node) void Node::Visit(AstStatementVisitor& visitor) \ + {\ + visitor.Visit(*this); \ } - - void StatementBlock::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - void DeclareVariable::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - void Discard::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType Identifier::GetExpressionType() const - { - assert(var); - return var->type; - } - - void Identifier::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - ShaderExpressionType AccessMember::GetExpressionType() const - { - return exprType; - } - - void AccessMember::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - void NoOp::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - void ReturnStatement::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - ShaderExpressionType AssignOp::GetExpressionType() const - { - return left->GetExpressionType(); - } - - void AssignOp::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType BinaryOp::GetExpressionType() const - { - std::optional exprType; - - switch (op) - { - case BinaryType::Add: - case BinaryType::Subtract: - exprType = left->GetExpressionType(); - break; - - case BinaryType::Divide: - case BinaryType::Multiply: - { - const ShaderExpressionType& leftExprType = left->GetExpressionType(); - assert(IsBasicType(leftExprType)); - - const ShaderExpressionType& rightExprType = right->GetExpressionType(); - assert(IsBasicType(rightExprType)); - - switch (std::get(leftExprType)) - { - case BasicType::Boolean: - case BasicType::Float2: - case BasicType::Float3: - case BasicType::Float4: - case BasicType::Int2: - case BasicType::Int3: - case BasicType::Int4: - case BasicType::UInt2: - case BasicType::UInt3: - case BasicType::UInt4: - exprType = leftExprType; - break; - - case BasicType::Float1: - case BasicType::Int1: - case BasicType::Mat4x4: - case BasicType::UInt1: - exprType = rightExprType; - break; - - case BasicType::Sampler2D: - case BasicType::Void: - break; - } - - break; - } - - case BinaryType::CompEq: - case BinaryType::CompGe: - case BinaryType::CompGt: - case BinaryType::CompLe: - case BinaryType::CompLt: - case BinaryType::CompNe: - exprType = BasicType::Boolean; - break; - } - - NazaraAssert(exprType.has_value(), "Unhandled builtin"); - - return *exprType; - } - - void BinaryOp::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - void Branch::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType Constant::GetExpressionType() const - { - return std::visit([&](auto&& arg) - { - using T = std::decay_t; - - if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Boolean; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Float1; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Int1; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Int1; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Float2; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Float3; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Float4; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Int2; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Int3; - else if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Int4; - else - static_assert(AlwaysFalse::value, "non-exhaustive visitor"); - }, value); - } - - void Constant::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - ShaderExpressionType Cast::GetExpressionType() const - { - return exprType; - } - - void Cast::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType ConditionalExpression::GetExpressionType() const - { - assert(truePath->GetExpressionType() == falsePath->GetExpressionType()); - return truePath->GetExpressionType(); - } - - void ConditionalExpression::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType SwizzleOp::GetExpressionType() const - { - const ShaderExpressionType& exprType = expression->GetExpressionType(); - assert(IsBasicType(exprType)); - - return static_cast(UnderlyingCast(GetComponentType(std::get(exprType))) + componentCount - 1); - } - - void SwizzleOp::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType Sample2D::GetExpressionType() const - { - return BasicType::Float4; - } - - void Sample2D::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } - - - ShaderExpressionType IntrinsicCall::GetExpressionType() const - { - switch (intrinsic) - { - case IntrinsicType::CrossProduct: - return parameters.front()->GetExpressionType(); - - case IntrinsicType::DotProduct: - return BasicType::Float1; - } - - NazaraAssert(false, "Unhandled builtin"); - return BasicType::Void; - } - - void IntrinsicCall::Visit(ShaderAstVisitor& visitor) - { - visitor.Visit(*this); - } +#include } diff --git a/src/Nazara/Shader/ShaderVarVisitorExcept.cpp b/src/Nazara/Shader/ShaderVarVisitorExcept.cpp deleted file mode 100644 index 3629f86d1..000000000 --- a/src/Nazara/Shader/ShaderVarVisitorExcept.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include - -namespace Nz -{ - void ShaderVarVisitorExcept::Visit(ShaderNodes::BuiltinVariable& /*var*/) - { - throw std::runtime_error("unhandled BuiltinVariable"); - } - - void ShaderVarVisitorExcept::Visit(ShaderNodes::InputVariable& /*var*/) - { - throw std::runtime_error("unhandled InputVariable"); - } - - void ShaderVarVisitorExcept::Visit(ShaderNodes::LocalVariable& /*var*/) - { - throw std::runtime_error("unhandled LocalVariable"); - } - - void ShaderVarVisitorExcept::Visit(ShaderNodes::OutputVariable& /*var*/) - { - throw std::runtime_error("unhandled OutputVariable"); - } - - void ShaderVarVisitorExcept::Visit(ShaderNodes::ParameterVariable& /*var*/) - { - throw std::runtime_error("unhandled ParameterVariable"); - } - - void ShaderVarVisitorExcept::Visit(ShaderNodes::UniformVariable& /*var*/) - { - throw std::runtime_error("unhandled UniformVariable"); - } -} diff --git a/src/Nazara/Shader/ShaderVariables.cpp b/src/Nazara/Shader/ShaderVariables.cpp deleted file mode 100644 index ebe520a0c..000000000 --- a/src/Nazara/Shader/ShaderVariables.cpp +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (C) 2020 Jérôme Leclercq -// This file is part of the "Nazara Engine - Shader generator" -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include -#include -#include - -namespace Nz::ShaderNodes -{ - ShaderNodes::Variable::~Variable() = default; - - VariableType BuiltinVariable::GetType() const - { - return VariableType::BuiltinVariable; - } - - void BuiltinVariable::Visit(ShaderVarVisitor& visitor) - { - visitor.Visit(*this); - } - - - VariableType InputVariable::GetType() const - { - return VariableType::InputVariable; - } - - void InputVariable::Visit(ShaderVarVisitor& visitor) - { - visitor.Visit(*this); - } - - - VariableType LocalVariable::GetType() const - { - return VariableType::LocalVariable; - } - - void LocalVariable::Visit(ShaderVarVisitor& visitor) - { - visitor.Visit(*this); - } - - - VariableType OutputVariable::GetType() const - { - return VariableType::OutputVariable; - } - - void OutputVariable::Visit(ShaderVarVisitor& visitor) - { - visitor.Visit(*this); - } - - - VariableType ParameterVariable::GetType() const - { - return VariableType::ParameterVariable; - } - - void ParameterVariable::Visit(ShaderVarVisitor& visitor) - { - visitor.Visit(*this); - } - - - VariableType UniformVariable::GetType() const - { - return VariableType::UniformVariable; - } - - void UniformVariable::Visit(ShaderVarVisitor& visitor) - { - visitor.Visit(*this); - } -} diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 688648e70..96a099a21 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -12,21 +13,21 @@ namespace Nz { - UInt32 SpirvAstVisitor::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr) + UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr) { - Visit(expr); + expr->Visit(*this); assert(m_resultIds.size() == 1); return PopResultId(); } - void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node) + void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node) { SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock); PushResultId(accessMemberVisitor.Evaluate(node)); } - void SpirvAstVisitor::Visit(ShaderNodes::AssignOp& node) + void SpirvAstVisitor::Visit(ShaderAst::AssignExpression& node) { UInt32 resultId = EvaluateExpression(node.right); @@ -36,20 +37,20 @@ namespace Nz PushResultId(resultId); } - void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node) + void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node) { - ShaderExpressionType resultExprType = node.GetExpressionType(); + ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node); assert(IsBasicType(resultExprType)); - const ShaderExpressionType& leftExprType = node.left->GetExpressionType(); + ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left); assert(IsBasicType(leftExprType)); - const ShaderExpressionType& rightExprType = node.right->GetExpressionType(); + ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right); assert(IsBasicType(rightExprType)); - ShaderNodes::BasicType resultType = std::get(resultExprType); - ShaderNodes::BasicType leftType = std::get(leftExprType); - ShaderNodes::BasicType rightType = std::get(rightExprType); + ShaderAst::BasicType resultType = std::get(resultExprType); + ShaderAst::BasicType leftType = std::get(leftExprType); + ShaderAst::BasicType rightType = std::get(rightExprType); UInt32 leftOperand = EvaluateExpression(node.left); @@ -62,308 +63,308 @@ namespace Nz { switch (node.op) { - case ShaderNodes::BinaryType::Add: + case ShaderAst::BinaryType::Add: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFAdd; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpIAdd; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::Subtract: + case ShaderAst::BinaryType::Subtract: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFSub; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpISub; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::Divide: + case ShaderAst::BinaryType::Divide: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFDiv; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: return SpirvOp::OpSDiv; - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpUDiv; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::CompEq: + case ShaderAst::BinaryType::CompEq: { switch (leftType) { - case ShaderNodes::BasicType::Boolean: + case ShaderAst::BasicType::Boolean: return SpirvOp::OpLogicalEqual; - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFOrdEqual; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpIEqual; - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::CompGe: + case ShaderAst::BinaryType::CompGe: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFOrdGreaterThan; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: return SpirvOp::OpSGreaterThan; - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpUGreaterThan; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::CompGt: + case ShaderAst::BinaryType::CompGt: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFOrdGreaterThanEqual; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: return SpirvOp::OpSGreaterThanEqual; - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpUGreaterThanEqual; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::CompLe: + case ShaderAst::BinaryType::CompLe: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFOrdLessThanEqual; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: return SpirvOp::OpSLessThanEqual; - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpULessThanEqual; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::CompLt: + case ShaderAst::BinaryType::CompLt: { switch (leftType) { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFOrdLessThan; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: return SpirvOp::OpSLessThan; - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpULessThan; - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Boolean: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::CompNe: + case ShaderAst::BinaryType::CompNe: { switch (leftType) { - case ShaderNodes::BasicType::Boolean: + case ShaderAst::BasicType::Boolean: return SpirvOp::OpLogicalNotEqual; - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Float1: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpFOrdNotEqual; - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpINotEqual; - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Sampler2D: + case ShaderAst::BasicType::Void: break; } break; } - case ShaderNodes::BinaryType::Multiply: + case ShaderAst::BinaryType::Multiply: { switch (leftType) { - case ShaderNodes::BasicType::Float1: + case ShaderAst::BasicType::Float1: { switch (rightType) { - case ShaderNodes::BasicType::Float1: + case ShaderAst::BasicType::Float1: return SpirvOp::OpFMul; - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: swapOperands = true; return SpirvOp::OpVectorTimesScalar; - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Mat4x4: swapOperands = true; return SpirvOp::OpMatrixTimesScalar; @@ -374,21 +375,21 @@ namespace Nz break; } - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: { switch (rightType) { - case ShaderNodes::BasicType::Float1: + case ShaderAst::BasicType::Float1: return SpirvOp::OpVectorTimesScalar; - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: return SpirvOp::OpFMul; - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpVectorTimesMatrix; default: @@ -398,23 +399,23 @@ namespace Nz break; } - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::Int1: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: + case ShaderAst::BasicType::UInt1: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: return SpirvOp::OpIMul; - case ShaderNodes::BasicType::Mat4x4: + case ShaderAst::BasicType::Mat4x4: { switch (rightType) { - case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar; - case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector; - case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; + case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar; + case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector; + case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; default: break; @@ -442,7 +443,7 @@ namespace Nz PushResultId(resultId); } - void SpirvAstVisitor::Visit(ShaderNodes::Branch& node) + void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node) { assert(!node.condStatements.empty()); auto& firstCond = node.condStatements.front(); @@ -450,7 +451,8 @@ namespace Nz UInt32 previousConditionId = EvaluateExpression(firstCond.condition); SpirvBlock previousContentBlock(m_writer); m_currentBlock = &previousContentBlock; - Visit(firstCond.statement); + + firstCond.statement->Visit(*this); SpirvBlock mergeBlock(m_writer); m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); @@ -458,7 +460,7 @@ namespace Nz std::optional nextBlock; for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex) { - const auto& statement = node.condStatements[statementIndex]; + auto& statement = node.condStatements[statementIndex]; SpirvBlock contentBlock(m_writer); @@ -469,7 +471,8 @@ namespace Nz previousContentBlock = std::move(contentBlock); m_currentBlock = &previousContentBlock; - Visit(statement.statement); + + statement.statement->Visit(*this); } if (node.elseStatement) @@ -477,7 +480,7 @@ namespace Nz SpirvBlock elseBlock(m_writer); m_currentBlock = &elseBlock; - Visit(node.elseStatement); + node.elseStatement->Visit(*this); elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice @@ -496,16 +499,16 @@ namespace Nz m_currentBlock = &m_blocks.back(); } - void SpirvAstVisitor::Visit(ShaderNodes::Cast& node) + void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node) { - const ShaderExpressionType& targetExprType = node.exprType; + const ShaderAst::ShaderExpressionType& targetExprType = node.targetType; assert(IsBasicType(targetExprType)); - ShaderNodes::BasicType targetType = std::get(targetExprType); + ShaderAst::BasicType targetType = std::get(targetExprType); StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); - for (const auto& exprPtr : node.expressions) + for (auto& exprPtr : node.expressions) { if (!exprPtr) break; @@ -527,21 +530,21 @@ namespace Nz PushResultId(resultId); } - void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node) + void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node) { if (m_writer.IsConditionEnabled(node.conditionName)) - Visit(node.truePath); + node.truePath->Visit(*this); else - Visit(node.falsePath); + node.falsePath->Visit(*this); } - void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node) + void SpirvAstVisitor::Visit(ShaderAst::ConditionalStatement& node) { if (m_writer.IsConditionEnabled(node.conditionName)) - Visit(node.statement); + node.statement->Visit(*this); } - void SpirvAstVisitor::Visit(ShaderNodes::Constant& node) + void SpirvAstVisitor::Visit(ShaderAst::ConstantExpression& node) { std::visit([&] (const auto& value) { @@ -549,46 +552,42 @@ namespace Nz }, node.value); } - void SpirvAstVisitor::Visit(ShaderNodes::DeclareVariable& node) + void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node) { - if (node.expression) - { - assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); - - const auto& localVar = static_cast(*node.variable); - m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression)); - } + if (node.initialExpression) + m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression)); } - void SpirvAstVisitor::Visit(ShaderNodes::Discard& /*node*/) + void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/) { m_currentBlock->Append(SpirvOp::OpKill); } - void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node) + void SpirvAstVisitor::Visit(ShaderAst::ExpressionStatement& node) { - Visit(node.expression); + node.expression->Visit(*this); + PopResultId(); } - void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node) + void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node) { SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock); PushResultId(loadVisitor.Evaluate(node)); } - void SpirvAstVisitor::Visit(ShaderNodes::IntrinsicCall& node) + void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node) { switch (node.intrinsic) { - case ShaderNodes::IntrinsicType::DotProduct: + case ShaderAst::IntrinsicType::DotProduct: { - const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType(); + const ShaderAst::ShaderExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); assert(IsBasicType(vecExprType)); - ShaderNodes::BasicType vecType = std::get(vecExprType); + ShaderAst::BasicType vecType = std::get(vecExprType); - UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType)); + UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType)); UInt32 vec1 = EvaluateExpression(node.parameters[0]); UInt32 vec2 = EvaluateExpression(node.parameters[1]); @@ -600,18 +599,18 @@ namespace Nz break; } - case ShaderNodes::IntrinsicType::CrossProduct: + case ShaderAst::IntrinsicType::CrossProduct: default: throw std::runtime_error("not yet implemented"); } } - void SpirvAstVisitor::Visit(ShaderNodes::NoOp& /*node*/) + void SpirvAstVisitor::Visit(ShaderAst::NoOpStatement& /*node*/) { // nothing to do } - void SpirvAstVisitor::Visit(ShaderNodes::ReturnStatement& node) + void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node) { if (node.returnExpr) m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr)); @@ -619,30 +618,18 @@ namespace Nz m_currentBlock->Append(SpirvOp::OpReturn); } - void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node) - { - UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4); - - UInt32 samplerId = EvaluateExpression(node.sampler); - UInt32 coordinatesId = EvaluateExpression(node.coordinates); - UInt32 resultId = m_writer.AllocateResultId(); - - m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId); - PushResultId(resultId); - } - - void SpirvAstVisitor::Visit(ShaderNodes::StatementBlock& node) + void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node) { for (auto& statement : node.statements) - Visit(statement); + statement->Visit(*this); } - void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node) + void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { - const ShaderExpressionType& targetExprType = node.GetExpressionType(); + const ShaderAst::ShaderExpressionType& targetExprType = ShaderAst::GetExpressionType(node); assert(IsBasicType(targetExprType)); - ShaderNodes::BasicType targetType = std::get(targetExprType); + ShaderAst::BasicType targetType = std::get(targetExprType); UInt32 exprResultId = EvaluateExpression(node.expression); UInt32 resultId = m_writer.AllocateResultId(); @@ -666,7 +653,7 @@ namespace Nz // Extract a single component from the vector assert(node.componentCount == 1); - m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) ); + m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) ); } PushResultId(resultId); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index fd27b2dc6..05108f1f6 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -3,7 +3,6 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include -#include #include #include #include @@ -536,7 +535,7 @@ namespace Nz else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ - BuildType((std::is_same_v) ? ShaderNodes::BasicType::Float2 : ShaderNodes::BasicType::Int2), + BuildType((std::is_same_v) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2), { BuildConstant(arg.x), BuildConstant(arg.y) @@ -546,7 +545,7 @@ namespace Nz else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ - BuildType((std::is_same_v) ? ShaderNodes::BasicType::Float3 : ShaderNodes::BasicType::Int3), + BuildType((std::is_same_v) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3), { BuildConstant(arg.x), BuildConstant(arg.y), @@ -557,7 +556,7 @@ namespace Nz else if constexpr (std::is_same_v || std::is_same_v) { return ConstantComposite{ - BuildType((std::is_same_v) ? ShaderNodes::BasicType::Float4 : ShaderNodes::BasicType::Int4), + BuildType((std::is_same_v) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4), { BuildConstant(arg.x), BuildConstant(arg.y), @@ -571,7 +570,7 @@ namespace Nz }, value)); } - auto SpirvConstantCache::BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass) -> TypePtr + auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr { return std::make_shared(SpirvConstantCache::Pointer{ SpirvConstantCache::BuildType(type), @@ -579,55 +578,55 @@ namespace Nz }); } - auto SpirvConstantCache::BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr + auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr { return std::make_shared(SpirvConstantCache::Pointer{ - SpirvConstantCache::BuildType(shader, type), + SpirvConstantCache::BuildType(type), storageClass }); } - auto SpirvConstantCache::BuildType(const ShaderNodes::BasicType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr { return std::make_shared([&]() -> AnyType { switch (type) { - case ShaderNodes::BasicType::Boolean: + case ShaderAst::BasicType::Boolean: return Bool{}; - case ShaderNodes::BasicType::Float1: + case ShaderAst::BasicType::Float1: return Float{ 32 }; - case ShaderNodes::BasicType::Int1: + case ShaderAst::BasicType::Int1: return Integer{ 32, true }; - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: + case ShaderAst::BasicType::Float2: + case ShaderAst::BasicType::Float3: + case ShaderAst::BasicType::Float4: + case ShaderAst::BasicType::Int2: + case ShaderAst::BasicType::Int3: + case ShaderAst::BasicType::Int4: + case ShaderAst::BasicType::UInt2: + case ShaderAst::BasicType::UInt3: + case ShaderAst::BasicType::UInt4: { - auto vecType = BuildType(ShaderNodes::Node::GetComponentType(type)); - UInt32 componentCount = ShaderNodes::Node::GetComponentCount(type); + auto vecType = BuildType(ShaderAst::GetComponentType(type)); + UInt32 componentCount = ShaderAst::GetComponentCount(type); return Vector{ vecType, componentCount }; } - case ShaderNodes::BasicType::Mat4x4: - return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u }; + case ShaderAst::BasicType::Mat4x4: + return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u }; - case ShaderNodes::BasicType::UInt1: + case ShaderAst::BasicType::UInt1: return Integer{ 32, false }; - case ShaderNodes::BasicType::Void: + case ShaderAst::BasicType::Void: return Void{}; - case ShaderNodes::BasicType::Sampler2D: + case ShaderAst::BasicType::Sampler2D: { auto imageType = Image{ {}, //< qualifier @@ -635,7 +634,7 @@ namespace Nz {}, //< sampled SpirvDim::Dim2D, //< dim SpirvImageFormat::Unknown, //< format - BuildType(ShaderNodes::BasicType::Float1), //< sampledType + BuildType(ShaderAst::BasicType::Float1), //< sampledType false, //< arrayed, false //< multisampled }; @@ -648,16 +647,16 @@ namespace Nz }()); } - auto SpirvConstantCache::BuildType(const ShaderAst& shader, const ShaderExpressionType& type) -> TypePtr + auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr { return std::visit([&](auto&& arg) -> TypePtr { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) return BuildType(arg); else if constexpr (std::is_same_v) { - // Register struct members type + /*// Register struct members type const auto& structs = shader.GetStructs(); auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; }); if (it == structs.end()) @@ -675,7 +674,8 @@ namespace Nz sMembers.type = BuildType(shader, member.type); } - return std::make_shared(std::move(sType)); + return std::make_shared(std::move(sType));*/ + return nullptr; } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index caddbec43..068280bd1 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -16,7 +16,7 @@ namespace Nz template overloaded(Ts...) -> overloaded; } - UInt32 SpirvExpressionLoad::Evaluate(ShaderNodes::Expression& node) + UInt32 SpirvExpressionLoad::Evaluate(ShaderAst::Expression& node) { node.Visit(*this); @@ -41,7 +41,7 @@ namespace Nz }, m_value); } - void SpirvExpressionLoad::Visit(ShaderNodes::AccessMember& node) + /*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node) { Visit(node.structExpr); @@ -49,6 +49,8 @@ namespace Nz { [&](const Pointer& pointer) { + ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr); + UInt32 resultId = m_writer.AllocateResultId(); UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME UInt32 typeId = m_writer.GetTypeId(node.exprType); @@ -87,40 +89,15 @@ namespace Nz throw std::runtime_error("an internal error occurred"); } }, m_value); - } + }*/ - void SpirvExpressionLoad::Visit(ShaderNodes::Identifier& node) + void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node) { - Visit(node.var); - } - - void SpirvExpressionLoad::Visit(ShaderNodes::InputVariable& var) - { - auto inputVar = m_writer.GetInputVariable(var.name); - - if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{})) - m_value = Value{ *resultIdOpt }; + if (node.identifier == "d") + m_value = Value{ m_writer.ReadLocalVariable(node.identifier) }; else - m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId }; - } + m_value = Value{ m_writer.ReadParameterVariable(node.identifier) }; - void SpirvExpressionLoad::Visit(ShaderNodes::LocalVariable& var) - { - m_value = Value{ m_writer.ReadLocalVariable(var.name) }; - } - - void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& var) - { - m_value = Value{ m_writer.ReadParameterVariable(var.name) }; - } - - void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var) - { - auto uniformVar = m_writer.GetUniformVariable(var.name); - - if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{})) - m_value = Value{ *resultIdOpt }; - else - m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId }; + //Visit(node.var); } } diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp index a0c5511d1..8655b3a94 100644 --- a/src/Nazara/Shader/SpirvExpressionStore.cpp +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -15,9 +15,9 @@ namespace Nz template overloaded(Ts...)->overloaded; } - void SpirvExpressionStore::Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId) + void SpirvExpressionStore::Store(ShaderAst::ExpressionPtr& node, UInt32 resultId) { - Visit(node); + node->Visit(*this); std::visit(overloaded { @@ -36,7 +36,7 @@ namespace Nz }, m_value); } - void SpirvExpressionStore::Visit(ShaderNodes::AccessMember& node) + /*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node) { Visit(node.structExpr); @@ -70,34 +70,15 @@ namespace Nz throw std::runtime_error("an internal error occurred"); } }, m_value); - } + }*/ - void SpirvExpressionStore::Visit(ShaderNodes::Identifier& node) + void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node) { - Visit(node.var); + m_value = LocalVar{ node.identifier }; } - void SpirvExpressionStore::Visit(ShaderNodes::SwizzleOp& node) + void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node) { throw std::runtime_error("not yet implemented"); } - - void SpirvExpressionStore::Visit(ShaderNodes::BuiltinVariable& var) - { - const auto& outputVar = m_writer.GetBuiltinVariable(var.entry); - - m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId }; - } - - void SpirvExpressionStore::Visit(ShaderNodes::LocalVariable& var) - { - m_value = LocalVar{ var.name }; - } - - void SpirvExpressionStore::Visit(ShaderNodes::OutputVariable& var) - { - const auto& outputVar = m_writer.GetOutputVariable(var.name); - - m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId }; - } } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 6316e42ee..d108d7efb 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -26,155 +26,131 @@ namespace Nz { namespace { - class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor + class PreVisitor : public ShaderAst::AstRecursiveVisitor { public: - using BuiltinContainer = std::unordered_set>; using ExtInstList = std::unordered_set; - using LocalContainer = std::unordered_set>; - using ParameterContainer = std::unordered_set< std::shared_ptr>; + using LocalContainer = std::unordered_set; - PreVisitor(const ShaderAst& shader, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : - m_shader(shader), + PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : + m_cache(cache), m_conditions(conditions), m_constantCache(constantCache) { } - using ShaderAstRecursiveVisitor::Visit; - using ShaderVarVisitor::Visit; - - void Visit(ShaderNodes::AccessMember& node) override + void Visit(ShaderAst::AccessMemberExpression& node) override { - for (std::size_t index : node.memberIndices) - m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index))); + /*for (std::size_t index : node.memberIdentifiers) + m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/ - ShaderAstRecursiveVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } - void Visit(ShaderNodes::ConditionalExpression& node) override + void Visit(ShaderAst::ConditionalExpression& node) override { - std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName); + /*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName); assert(conditionIndex != ShaderAst::InvalidCondition); if (TestBit(m_conditions.enabledConditions, conditionIndex)) Visit(node.truePath); else - Visit(node.falsePath); + Visit(node.falsePath);*/ } - void Visit(ShaderNodes::ConditionalStatement& node) override + void Visit(ShaderAst::ConditionalStatement& node) override { - std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName); + /*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName); assert(conditionIndex != ShaderAst::InvalidCondition); if (TestBit(m_conditions.enabledConditions, conditionIndex)) - Visit(node.statement); + Visit(node.statement);*/ } - void Visit(ShaderNodes::Constant& node) override + void Visit(ShaderAst::ConstantExpression& node) override { std::visit([&](auto&& arg) { m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg)); }, node.value); - ShaderAstRecursiveVisitor::Visit(node); + AstRecursiveVisitor::Visit(node); } - void Visit(ShaderNodes::DeclareVariable& node) override + void Visit(ShaderAst::DeclareFunctionStatement& node) override { - Visit(node.variable); - - ShaderAstRecursiveVisitor::Visit(node); + m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType)); + for (auto& parameter : node.parameters) + m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type)); } - void Visit(ShaderNodes::Identifier& node) override + void Visit(ShaderAst::DeclareStructStatement& node) override { - Visit(node.var); - - ShaderAstRecursiveVisitor::Visit(node); + for (auto& field : node.description.members) + m_constantCache.Register(*SpirvConstantCache::BuildType(field.type)); } - void Visit(ShaderNodes::IntrinsicCall& node) override + void Visit(ShaderAst::DeclareVariableStatement& node) override { - ShaderAstRecursiveVisitor::Visit(node); + variableTypes.insert(node.varType); + + AstRecursiveVisitor::Visit(node); + } + + void Visit(ShaderAst::IdentifierExpression& node) override + { + variableTypes.insert(GetExpressionType(node, m_cache)); + + AstRecursiveVisitor::Visit(node); + } + + void Visit(ShaderAst::IntrinsicExpression& node) override + { + AstRecursiveVisitor::Visit(node); switch (node.intrinsic) { // Require GLSL.std.450 - case ShaderNodes::IntrinsicType::CrossProduct: + case ShaderAst::IntrinsicType::CrossProduct: extInsts.emplace("GLSL.std.450"); break; // Part of SPIR-V core - case ShaderNodes::IntrinsicType::DotProduct: + case ShaderAst::IntrinsicType::DotProduct: break; } } - void Visit(ShaderNodes::BuiltinVariable& var) override - { - builtinVars.insert(std::static_pointer_cast(var.shared_from_this())); - } - - void Visit(ShaderNodes::InputVariable& /*var*/) override - { - /* Handled by ShaderAst */ - } - - void Visit(ShaderNodes::LocalVariable& var) override - { - localVars.insert(std::static_pointer_cast(var.shared_from_this())); - } - - void Visit(ShaderNodes::OutputVariable& /*var*/) override - { - /* Handled by ShaderAst */ - } - - void Visit(ShaderNodes::ParameterVariable& var) override - { - paramVars.insert(std::static_pointer_cast(var.shared_from_this())); - } - - void Visit(ShaderNodes::UniformVariable& /*var*/) override - { - /* Handled by ShaderAst */ - } - - BuiltinContainer builtinVars; ExtInstList extInsts; - LocalContainer localVars; - ParameterContainer paramVars; + LocalContainer variableTypes; private: - const ShaderAst& m_shader; + ShaderAst::AstCache* m_cache; const SpirvWriter::States& m_conditions; SpirvConstantCache& m_constantCache; }; template - constexpr ShaderNodes::BasicType GetBasicType() + constexpr ShaderAst::BasicType GetBasicType() { if constexpr (std::is_same_v) - return ShaderNodes::BasicType::Boolean; + return ShaderAst::BasicType::Boolean; else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Float1); + return(ShaderAst::BasicType::Float1); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Int1); + return(ShaderAst::BasicType::Int1); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Float2); + return(ShaderAst::BasicType::Float2); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Float3); + return(ShaderAst::BasicType::Float3); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Float4); + return(ShaderAst::BasicType::Float4); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Int2); + return(ShaderAst::BasicType::Int2); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Int3); + return(ShaderAst::BasicType::Int3); else if constexpr (std::is_same_v) - return(ShaderNodes::BasicType::Int4); + return(ShaderAst::BasicType::Int4); else static_assert(AlwaysFalse::value, "unhandled type"); } @@ -198,7 +174,7 @@ namespace Nz tsl::ordered_map parameterIds; tsl::ordered_map uniformIds; std::unordered_map extensionInstructions; - std::unordered_map builtinIds; + std::unordered_map builtinIds; std::unordered_map varToResult; std::vector funcs; std::vector functionBlocks; @@ -219,13 +195,12 @@ namespace Nz { } - std::vector SpirvWriter::Generate(const ShaderAst& shader, const States& conditions) + std::vector SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) { std::string error; - if (!ValidateShader(shader, &error)) + if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache)) throw std::runtime_error("Invalid shader AST: " + error); - m_context.shader = &shader; m_context.states = &conditions; State state; @@ -235,23 +210,19 @@ namespace Nz m_currentState = nullptr; }); - std::vector functionStatements; + std::vector functionStatements; - ShaderAstCloner cloner; - - PreVisitor preVisitor(shader, conditions, state.constantTypeCache); - for (const auto& func : shader.GetFunctions()) - { - functionStatements.emplace_back(cloner.Clone(func.statement)); - preVisitor.Visit(func.statement); - } + ShaderAst::AstCloner cloner; // Register all extended instruction sets + PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache); + shader->Visit(preVisitor); + for (const std::string& extInst : preVisitor.extInsts) state.extensionInstructions[extInst] = AllocateResultId(); // Register all types - for (const auto& func : shader.GetFunctions()) + /*for (const auto& func : shader.GetFunctions()) { RegisterType(func.returnType); for (const auto& param : func.parameters) @@ -270,8 +241,8 @@ namespace Nz for (const auto& func : shader.GetFunctions()) RegisterFunctionType(func.returnType, func.parameters); - for (const auto& local : preVisitor.localVars) - RegisterType(local->type); + for (const auto& type : preVisitor.variableTypes) + RegisterType(type); for (const auto& builtin : preVisitor.builtinVars) RegisterType(builtin->type); @@ -283,7 +254,7 @@ namespace Nz SpirvBuiltIn builtinDecoration; switch (builtin->entry) { - case ShaderNodes::BuiltinEntry::VertexPosition: + case ShaderAst::BuiltinEntry::VertexPosition: variable.debugName = "builtin_VertexPosition"; variable.storageClass = SpirvStorageClass::Output; @@ -294,10 +265,10 @@ namespace Nz throw std::runtime_error("unexpected builtin type"); } - const ShaderExpressionType& builtinExprType = builtin->type; + const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type; assert(IsBasicType(builtinExprType)); - ShaderNodes::BasicType builtinType = std::get(builtinExprType); + ShaderAst::BasicType builtinType = std::get(builtinExprType); variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass); @@ -420,7 +391,7 @@ namespace Nz if (!state.functionBlocks.back().IsTerminated()) { - assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void)); + assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void)); state.functionBlocks.back().Append(SpirvOp::OpReturn); } @@ -475,14 +446,14 @@ namespace Nz if (m_context.shader->GetStage() == ShaderStageType::Fragment) state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft); - } + }*/ std::vector ret; - MergeSections(ret, state.header); + /*MergeSections(ret, state.header); MergeSections(ret, state.debugInfo); MergeSections(ret, state.annotations); MergeSections(ret, state.constants); - MergeSections(ret, state.instructions); + MergeSections(ret, state.instructions);*/ return ret; } @@ -516,16 +487,16 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); } - SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector& parameters) + SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters) { std::vector parameterTypes; parameterTypes.reserve(parameters.size()); for (const auto& parameter : parameters) - parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function)); + parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function)); return SpirvConstantCache::Function{ - SpirvConstantCache::BuildType(*m_context.shader, retType), + SpirvConstantCache::BuildType(retType), std::move(parameterTypes) }; } @@ -535,12 +506,12 @@ namespace Nz return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); } - UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters) + UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector& parameters) { return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) }); } - auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar& + auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar& { auto it = m_currentState->builtinIds.find(builtin); assert(it != m_currentState->builtinIds.end()); @@ -572,14 +543,14 @@ namespace Nz return it.value(); } - UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const + UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const { - return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass)); + return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass)); } - UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const + UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const { - return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type)); + return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type)); } UInt32 SpirvWriter::ReadInputVariable(const std::string& name) @@ -673,20 +644,20 @@ namespace Nz return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value)); } - UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector& parameters) + UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters) { return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) }); } - UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass) + UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass) { - return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass)); + return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass)); } - UInt32 SpirvWriter::RegisterType(ShaderExpressionType type) + UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type) { assert(m_currentState); - return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type)); + return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type)); } void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)