diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index ad76fbe18..e7b92505c 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -16,7 +16,6 @@ #include #include #include -#include namespace Nz { @@ -31,7 +30,7 @@ namespace Nz GlslWriter(GlslWriter&&) = delete; ~GlslWriter() = default; - std::string Generate(const ShaderAst& shader) override; + std::string Generate(const ShaderAst& shader, const States& conditions = {}); void SetEnv(Environment environment); @@ -70,6 +69,8 @@ namespace Nz void Visit(ShaderNodes::BinaryOp& node) override; void Visit(ShaderNodes::BuiltinVariable& var) override; void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::ConditionalStatement& node) override; void Visit(ShaderNodes::Constant& node) override; void Visit(ShaderNodes::DeclareVariable& node) override; void Visit(ShaderNodes::ExpressionStatement& node) override; @@ -91,6 +92,7 @@ namespace Nz { const ShaderAst* shader = nullptr; const ShaderAst::Function* currentFunction = nullptr; + const States* states = nullptr; }; struct State diff --git a/include/Nazara/Shader/ShaderAst.hpp b/include/Nazara/Shader/ShaderAst.hpp index 30ed90fac..5b80a1380 100644 --- a/include/Nazara/Shader/ShaderAst.hpp +++ b/include/Nazara/Shader/ShaderAst.hpp @@ -20,6 +20,7 @@ namespace Nz class NAZARA_SHADER_API ShaderAst { public: + struct Condition; struct Function; struct FunctionParameter; struct InputOutput; @@ -33,12 +34,16 @@ namespace Nz ShaderAst(ShaderAst&&) noexcept = default; ~ShaderAst() = default; + void AddCondition(std::string name); void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters = {}, ShaderNodes::BasicType returnType = ShaderNodes::BasicType::Void); void AddInput(std::string name, ShaderExpressionType type, std::optional locationIndex = {}); void AddOutput(std::string name, ShaderExpressionType type, std::optional locationIndex = {}); void AddStruct(std::string name, std::vector members); void AddUniform(std::string name, ShaderExpressionType type, std::optional bindingIndex = {}, std::optional memoryLayout = {}); + inline const Condition& GetCondition(std::size_t i) const; + inline std::size_t GetConditionCount() const; + inline const std::vector& GetConditions() const; inline const Function& GetFunction(std::size_t i) const; inline std::size_t GetFunctionCount() const; inline const std::vector& GetFunctions() const; @@ -59,6 +64,11 @@ namespace Nz ShaderAst& operator=(const ShaderAst&) = default; ShaderAst& operator=(ShaderAst&&) noexcept = default; + struct Condition + { + std::string name; + }; + struct VariableBase { std::string name; @@ -101,6 +111,7 @@ namespace Nz }; private: + std::vector m_conditions; std::vector m_functions; std::vector m_inputs; std::vector m_outputs; diff --git a/include/Nazara/Shader/ShaderAst.inl b/include/Nazara/Shader/ShaderAst.inl index 4c6c42833..f00517db5 100644 --- a/include/Nazara/Shader/ShaderAst.inl +++ b/include/Nazara/Shader/ShaderAst.inl @@ -12,6 +12,22 @@ namespace Nz { } + inline auto Nz::ShaderAst::GetCondition(std::size_t i) const -> const Condition& + { + assert(i < m_functions.size()); + return m_conditions[i]; + } + + inline std::size_t ShaderAst::GetConditionCount() const + { + return m_conditions.size(); + } + + inline auto ShaderAst::GetConditions() const -> const std::vector& + { + return m_conditions; + } + inline auto ShaderAst::GetFunction(std::size_t i) const -> const Function& { assert(i < m_functions.size()); diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index 34336e9ba..afa7cb5b2 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -38,6 +38,8 @@ namespace Nz void Visit(ShaderNodes::BinaryOp& node) override; void Visit(ShaderNodes::Branch& node) override; void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::ConditionalStatement& node) override; void Visit(ShaderNodes::Constant& node) override; void Visit(ShaderNodes::DeclareVariable& node) override; void Visit(ShaderNodes::ExpressionStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp index 367f9b4ca..5dd1bc5fe 100644 --- a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp @@ -26,6 +26,8 @@ namespace Nz void Visit(ShaderNodes::BinaryOp& node) override; void Visit(ShaderNodes::Branch& node) override; void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::ConditionalStatement& node) override; void Visit(ShaderNodes::Constant& node) override; void Visit(ShaderNodes::DeclareVariable& node) override; void Visit(ShaderNodes::ExpressionStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstSerializer.hpp b/include/Nazara/Shader/ShaderAstSerializer.hpp index ac03cc66d..e2db77fdf 100644 --- a/include/Nazara/Shader/ShaderAstSerializer.hpp +++ b/include/Nazara/Shader/ShaderAstSerializer.hpp @@ -31,6 +31,8 @@ namespace Nz void Serialize(ShaderNodes::BuiltinVariable& var); void Serialize(ShaderNodes::Branch& node); void Serialize(ShaderNodes::Cast& node); + void Serialize(ShaderNodes::ConditionalExpression& node); + void Serialize(ShaderNodes::ConditionalStatement& node); void Serialize(ShaderNodes::Constant& node); void Serialize(ShaderNodes::DeclareVariable& node); void Serialize(ShaderNodes::ExpressionStatement& node); diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index 8195494b6..579893aa4 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -41,6 +41,8 @@ namespace Nz void Visit(ShaderNodes::BinaryOp& node) override; void Visit(ShaderNodes::Branch& node) override; void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::ConditionalStatement& node) override; void Visit(ShaderNodes::Constant& node) override; void Visit(ShaderNodes::DeclareVariable& node) override; void Visit(ShaderNodes::ExpressionStatement& node) override; diff --git a/include/Nazara/Shader/ShaderAstVisitor.hpp b/include/Nazara/Shader/ShaderAstVisitor.hpp index decd3a2f9..9ac8d764c 100644 --- a/include/Nazara/Shader/ShaderAstVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstVisitor.hpp @@ -10,8 +10,6 @@ #include #include #include -#include -#include namespace Nz { @@ -23,16 +21,14 @@ namespace Nz ShaderAstVisitor(ShaderAstVisitor&&) = delete; virtual ~ShaderAstVisitor(); - void EnableCondition(const std::string& name, bool cond); - - bool IsConditionEnabled(const std::string& name) const; - void Visit(const ShaderNodes::NodePtr& node); virtual void Visit(ShaderNodes::AccessMember& node) = 0; virtual void Visit(ShaderNodes::AssignOp& node) = 0; virtual void Visit(ShaderNodes::BinaryOp& node) = 0; virtual void Visit(ShaderNodes::Branch& node) = 0; virtual void Visit(ShaderNodes::Cast& node) = 0; + virtual void Visit(ShaderNodes::ConditionalExpression& node) = 0; + virtual void Visit(ShaderNodes::ConditionalStatement& node) = 0; virtual void Visit(ShaderNodes::Constant& node) = 0; virtual void Visit(ShaderNodes::DeclareVariable& node) = 0; virtual void Visit(ShaderNodes::ExpressionStatement& node) = 0; @@ -44,9 +40,6 @@ namespace Nz ShaderAstVisitor& operator=(const ShaderAstVisitor&) = delete; ShaderAstVisitor& operator=(ShaderAstVisitor&&) = delete; - - private: - std::unordered_set m_conditions; }; } diff --git a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp index 40635477c..0c58c5472 100644 --- a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp +++ b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp @@ -22,6 +22,8 @@ namespace Nz void Visit(ShaderNodes::BinaryOp& node) override; void Visit(ShaderNodes::Branch& node) override; void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::ConditionalStatement& node) override; void Visit(ShaderNodes::Constant& node) override; void Visit(ShaderNodes::DeclareVariable& node) override; void Visit(ShaderNodes::ExpressionStatement& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 074434a44..0965c7f0c 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -50,6 +50,7 @@ namespace Nz::ShaderBuilder constexpr BuiltinBuilder Builtin; constexpr GenBuilder Block; constexpr GenBuilder Branch; + constexpr GenBuilder ConditionalExpression; constexpr GenBuilder ConditionalStatement; constexpr GenBuilder Constant; constexpr GenBuilder DeclareVariable; diff --git a/include/Nazara/Shader/ShaderEnums.hpp b/include/Nazara/Shader/ShaderEnums.hpp index ed322e3e6..57f76d0f5 100644 --- a/include/Nazara/Shader/ShaderEnums.hpp +++ b/include/Nazara/Shader/ShaderEnums.hpp @@ -78,6 +78,7 @@ namespace Nz::ShaderNodes Branch, Cast, Constant, + ConditionalExpression, ConditionalStatement, DeclareVariable, ExpressionStatement, diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index b3af5f1aa..8639dbc99 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -217,6 +217,20 @@ namespace Nz static inline std::shared_ptr Build(BasicType castTo, ExpressionPtr* expressions, std::size_t expressionCount); }; + struct NAZARA_SHADER_API ConditionalExpression : public Expression + { + inline ConditionalExpression(); + + ShaderExpressionType GetExpressionType() const override; + void Visit(ShaderAstVisitor& visitor) override; + + std::string conditionName; + ExpressionPtr falsePath; + ExpressionPtr truePath; + + static inline std::shared_ptr Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath); + }; + struct NAZARA_SHADER_API Constant : public Expression { inline Constant(); diff --git a/include/Nazara/Shader/ShaderNodes.inl b/include/Nazara/Shader/ShaderNodes.inl index 1e0817b62..b61841430 100644 --- a/include/Nazara/Shader/ShaderNodes.inl +++ b/include/Nazara/Shader/ShaderNodes.inl @@ -263,6 +263,20 @@ namespace Nz::ShaderNodes return node; } + inline ConditionalExpression::ConditionalExpression() : + Expression(NodeType::ConditionalExpression) + { + } + + inline std::shared_ptr ShaderNodes::ConditionalExpression::Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath) + { + auto node = std::make_shared(); + node->conditionName = std::move(condition); + node->falsePath = std::move(falsePath); + node->truePath = std::move(truePath); + + return node; + } inline Constant::Constant() : Expression(NodeType::Constant) diff --git a/include/Nazara/Shader/ShaderWriter.hpp b/include/Nazara/Shader/ShaderWriter.hpp index 0e896fa40..3fd25d7ff 100644 --- a/include/Nazara/Shader/ShaderWriter.hpp +++ b/include/Nazara/Shader/ShaderWriter.hpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace Nz { @@ -18,12 +19,17 @@ namespace Nz class NAZARA_SHADER_API ShaderWriter { public: + struct States; + ShaderWriter() = default; ShaderWriter(const ShaderWriter&) = default; ShaderWriter(ShaderWriter&&) = default; virtual ~ShaderWriter(); - virtual std::string Generate(const ShaderAst& shader) = 0; + struct States + { + std::unordered_set enabledConditions; + }; }; } diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 743dd5130..d7196950f 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -32,6 +32,8 @@ namespace Nz void Visit(ShaderNodes::AssignOp& node) override; void Visit(ShaderNodes::BinaryOp& node) override; void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::ConditionalStatement& node) override; void Visit(ShaderNodes::Constant& node) override; void Visit(ShaderNodes::DeclareVariable& node) override; void Visit(ShaderNodes::ExpressionStatement& node) override; diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 6b21de0ba..86f3f156b 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -23,7 +23,7 @@ namespace Nz { class SpirvSection; - class NAZARA_SHADER_API SpirvWriter + class NAZARA_SHADER_API SpirvWriter : public ShaderWriter { friend class SpirvAstVisitor; friend class SpirvExpressionLoad; @@ -38,7 +38,7 @@ namespace Nz SpirvWriter(SpirvWriter&&) = delete; ~SpirvWriter() = default; - std::vector Generate(const ShaderAst& shader); + std::vector Generate(const ShaderAst& shader, const States& conditions = {}); void SetEnv(Environment environment); @@ -66,6 +66,8 @@ namespace Nz UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const; UInt32 GetTypeId(const ShaderExpressionType& type) const; + inline bool IsConditionEnabled(const std::string& condition) const; + UInt32 ReadInputVariable(const std::string& name); std::optional ReadInputVariable(const std::string& name, OnlyCache); UInt32 ReadLocalVariable(const std::string& name); @@ -88,6 +90,7 @@ namespace Nz { const ShaderAst* shader = nullptr; const ShaderAst::Function* currentFunction = nullptr; + const States* states = nullptr; }; struct ExtVar diff --git a/include/Nazara/Shader/SpirvWriter.inl b/include/Nazara/Shader/SpirvWriter.inl index 26012e0d1..44babe815 100644 --- a/include/Nazara/Shader/SpirvWriter.inl +++ b/include/Nazara/Shader/SpirvWriter.inl @@ -7,6 +7,10 @@ namespace Nz { + inline bool SpirvWriter::IsConditionEnabled(const std::string& condition) const + { + return m_context.states->enabledConditions.find(condition) != m_context.states->enabledConditions.end(); + } } #include diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index ccea40fe7..248c92bb7 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -48,12 +48,13 @@ namespace Nz { } - std::string GlslWriter::Generate(const ShaderAst& shader) + std::string GlslWriter::Generate(const ShaderAst& shader, const States& conditions) { std::string error; if (!ValidateShader(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); + m_context.states = &conditions; m_context.shader = &shader; State state; @@ -461,6 +462,21 @@ namespace Nz Append(")"); } + + void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node) + { + if (m_context.states->enabledConditions.count(node.conditionName) != 0) + Visit(node.truePath); + else + Visit(node.falsePath); + } + + void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node) + { + if (m_context.states->enabledConditions.count(node.conditionName) != 0) + Visit(node.statement); + } + void GlslWriter::Visit(ShaderNodes::Constant& node) { std::visit([&](auto&& arg) diff --git a/src/Nazara/Shader/ShaderAst.cpp b/src/Nazara/Shader/ShaderAst.cpp index 236a5e28f..2f43ee23b 100644 --- a/src/Nazara/Shader/ShaderAst.cpp +++ b/src/Nazara/Shader/ShaderAst.cpp @@ -7,6 +7,12 @@ namespace Nz { + void ShaderAst::AddCondition(std::string name) + { + auto& conditionEntry = m_conditions.emplace_back(); + conditionEntry.name = std::move(name); + } + void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters, ShaderNodes::BasicType returnType) { auto& functionEntry = m_functions.emplace_back(); diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index 90c7c909d..cef0ac694 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -91,6 +91,16 @@ namespace Nz PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount)); } + void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node) + { + PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath))); + } + + void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node) + { + PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement))); + } + void ShaderAstCloner::Visit(ShaderNodes::Constant& node) { PushExpression(ShaderNodes::Constant::Build(node.value)); diff --git a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp index 98fdbfee6..b68d5361a 100644 --- a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp @@ -47,6 +47,17 @@ namespace Nz } } + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node) + { + Visit(node.truePath); + Visit(node.falsePath); + } + + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node) + { + Visit(node.statement); + } + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/) { /* Nothing to do */ diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index 4f1c214d2..5184d980d 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -47,6 +47,16 @@ namespace Nz Serialize(node); } + void Visit(ShaderNodes::ConditionalExpression& node) override + { + Serialize(node); + } + + void Visit(ShaderNodes::ConditionalStatement& node) override + { + Serialize(node); + } + void Visit(ShaderNodes::Constant& node) override { Serialize(node); @@ -179,6 +189,19 @@ namespace Nz Node(expr); } + void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node) + { + Value(node.conditionName); + Node(node.truePath); + Node(node.falsePath); + } + + void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node) + { + Value(node.conditionName); + Node(node.statement); + } + void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node) { UInt32 typeIndex; @@ -306,6 +329,12 @@ namespace Nz } }; + // Conditions + m_stream << UInt32(shader.GetConditionCount()); + for (const auto& cond : shader.GetConditions()) + m_stream << cond.name; + + // Structs m_stream << UInt32(shader.GetStructCount()); for (const auto& s : shader.GetStructs()) { @@ -318,9 +347,11 @@ namespace Nz } } + // Inputs / Outputs SerializeInputOutput(shader.GetInputs()); SerializeInputOutput(shader.GetOutputs()); + // Uniforms m_stream << UInt32(shader.GetUniformCount()); for (const auto& uniform : shader.GetUniforms()) { @@ -336,6 +367,7 @@ namespace Nz m_stream << UInt32(uniform.memoryLayout.value()); } + // Functions m_stream << UInt32(shader.GetFunctionCount()); for (const auto& func : shader.GetFunctions()) { @@ -495,6 +527,18 @@ namespace Nz ShaderAst shader(static_cast(shaderStage)); + // Conditions + UInt32 conditionCount; + m_stream >> conditionCount; + for (UInt32 i = 0; i < conditionCount; ++i) + { + std::string conditionName; + Value(conditionName); + + shader.AddCondition(std::move(conditionName)); + } + + // Structs UInt32 structCount; m_stream >> structCount; for (UInt32 i = 0; i < structCount; ++i) @@ -514,6 +558,7 @@ namespace Nz shader.AddStruct(std::move(structName), std::move(members)); } + // Inputs UInt32 inputCount; m_stream >> inputCount; for (UInt32 i = 0; i < inputCount; ++i) @@ -529,6 +574,7 @@ namespace Nz shader.AddInput(std::move(inputName), std::move(inputType), location); } + // Outputs UInt32 outputCount; m_stream >> outputCount; for (UInt32 i = 0; i < outputCount; ++i) @@ -544,6 +590,7 @@ namespace Nz shader.AddOutput(std::move(outputName), std::move(outputType), location); } + // Uniforms UInt32 uniformCount; m_stream >> uniformCount; for (UInt32 i = 0; i < uniformCount; ++i) @@ -561,6 +608,7 @@ namespace Nz shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout)); } + // Functions UInt32 funcCount; m_stream >> funcCount; for (UInt32 i = 0; i < funcCount; ++i) @@ -614,6 +662,7 @@ namespace Nz HandleType(Branch); HandleType(Cast); HandleType(Constant); + HandleType(ConditionalExpression); HandleType(ConditionalStatement); HandleType(DeclareVariable); HandleType(ExpressionStatement); diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 95ae20df0..643ff2d85 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -241,6 +241,35 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } + void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node) + { + MandatoryNode(node.truePath); + MandatoryNode(node.falsePath); + + for (std::size_t i = 0; i < m_shader.GetConditionCount(); ++i) + { + const auto& condition = m_shader.GetCondition(i); + if (condition.name == node.conditionName) + return; + } + + throw AstError{ "Condition not found" }; + } + + void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node) + { + MandatoryNode(node.statement); + + for (std::size_t i = 0; i < m_shader.GetConditionCount(); ++i) + { + const auto& condition = m_shader.GetCondition(i); + if (condition.name == node.conditionName) + return; + } + + throw AstError{ "Condition not found" }; + } + void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/) { } diff --git a/src/Nazara/Shader/ShaderAstVisitor.cpp b/src/Nazara/Shader/ShaderAstVisitor.cpp index 53325c9ca..719e3ad99 100644 --- a/src/Nazara/Shader/ShaderAstVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstVisitor.cpp @@ -9,19 +9,6 @@ namespace Nz { ShaderAstVisitor::~ShaderAstVisitor() = default; - void ShaderAstVisitor::EnableCondition(const std::string& name, bool cond) - { - if (cond) - m_conditions.insert(name); - else - m_conditions.erase(name); - } - - bool ShaderAstVisitor::IsConditionEnabled(const std::string& name) const - { - return m_conditions.count(name) != 0; - } - void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node) { node->Visit(*this); diff --git a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp index 419d073bb..85ae826d1 100644 --- a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp +++ b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp @@ -33,6 +33,16 @@ namespace Nz throw std::runtime_error("unhandled Cast node"); } + void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/) + { + throw std::runtime_error("unhandled ConditionalExpression node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/) + { + throw std::runtime_error("unhandled ConditionalStatement node"); + } + void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/) { throw std::runtime_error("unhandled Constant node"); diff --git a/src/Nazara/Shader/ShaderNodes.cpp b/src/Nazara/Shader/ShaderNodes.cpp index bdc2897db..6530c1713 100644 --- a/src/Nazara/Shader/ShaderNodes.cpp +++ b/src/Nazara/Shader/ShaderNodes.cpp @@ -26,8 +26,7 @@ namespace Nz::ShaderNodes void ConditionalStatement::Visit(ShaderAstVisitor& visitor) { - if (visitor.IsConditionEnabled(conditionName)) - statement->Visit(visitor); + visitor.Visit(*this); } @@ -204,6 +203,18 @@ namespace Nz::ShaderNodes } + ShaderExpressionType ConditionalExpression::GetExpressionType() const + { + assert(truePath->GetExpressionType() == falsePath->GetExpressionType()); + return truePath->GetExpressionType(); + } + + void ConditionalExpression::Visit(ShaderAstVisitor& visitor) + { + visitor.Visit(*this); + } + + ExpressionCategory SwizzleOp::GetExpressionCategory() const { return expression->GetExpressionCategory(); diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 4c419f987..2cb244496 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -313,6 +313,20 @@ namespace Nz PushResultId(resultId); } + void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node) + { + if (m_writer.IsConditionEnabled(node.conditionName)) + Visit(node.truePath); + else + Visit(node.falsePath); + } + + void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node) + { + if (m_writer.IsConditionEnabled(node.conditionName)) + Visit(node.statement); + } + void SpirvAstVisitor::Visit(ShaderNodes::Constant& node) { std::visit([&] (const auto& value) diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 45e1785b2..c8ad4d4b4 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -33,7 +33,8 @@ namespace Nz using LocalContainer = std::unordered_set>; using ParameterContainer = std::unordered_set< std::shared_ptr>; - PreVisitor(SpirvConstantCache& constantCache) : + PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : + m_conditions(conditions), m_constantCache(constantCache) { } @@ -49,6 +50,20 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } + void Visit(ShaderNodes::ConditionalExpression& node) override + { + if (m_conditions.enabledConditions.count(node.conditionName) != 0) + Visit(node.truePath); + else + Visit(node.falsePath); + } + + void Visit(ShaderNodes::ConditionalStatement& node) override + { + if (m_conditions.enabledConditions.count(node.conditionName) != 0) + Visit(node.statement); + } + void Visit(ShaderNodes::Constant& node) override { std::visit([&](auto&& arg) @@ -126,6 +141,7 @@ namespace Nz ParameterContainer paramVars; private: + const SpirvWriter::States& m_conditions; SpirvConstantCache& m_constantCache; }; @@ -193,13 +209,14 @@ namespace Nz { } - std::vector SpirvWriter::Generate(const ShaderAst& shader) + std::vector SpirvWriter::Generate(const ShaderAst& shader, const States& conditions) { std::string error; if (!ValidateShader(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); m_context.shader = &shader; + m_context.states = &conditions; State state; m_currentState = &state; @@ -212,7 +229,7 @@ namespace Nz ShaderAstCloner cloner; - PreVisitor preVisitor(state.constantTypeCache); + PreVisitor preVisitor(conditions, state.constantTypeCache); for (const auto& func : shader.GetFunctions()) { functionStatements.emplace_back(cloner.Clone(func.statement)); @@ -450,7 +467,7 @@ namespace Nz m_environment = std::move(environment); } - UInt32 Nz::SpirvWriter::AllocateResultId() + UInt32 SpirvWriter::AllocateResultId() { return m_currentState->nextVarIndex++; } diff --git a/src/ShaderNode/DataModels/ConditionalExpression.cpp b/src/ShaderNode/DataModels/ConditionalExpression.cpp new file mode 100644 index 000000000..d034ada49 --- /dev/null +++ b/src/ShaderNode/DataModels/ConditionalExpression.cpp @@ -0,0 +1,240 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +ConditionalExpression::ConditionalExpression(ShaderGraph& graph) : +ShaderNode(graph) +{ + m_onConditionListUpdateSlot.Connect(GetGraph().OnConditionListUpdate, [&](ShaderGraph*) { OnConditionListUpdate(); }); + m_onConditionUpdateSlot.Connect(GetGraph().OnConditionUpdate, [&](ShaderGraph*, std::size_t conditionIndex) + { + if (m_currentConditionIndex == conditionIndex) + { + UpdatePreview(); + Q_EMIT dataUpdated(0); + } + }); + + if (graph.GetConditionCount() > 0) + { + m_currentConditionIndex = 0; + UpdateConditionText(); + } + + EnablePreview(); + SetPreviewSize({ 128, 128 }); + UpdatePreview(); +} + +Nz::ShaderNodes::ExpressionPtr ConditionalExpression::GetExpression(Nz::ShaderNodes::ExpressionPtr* expressions, std::size_t count) const +{ + assert(count == 2); + + if (!m_currentConditionIndex) + throw std::runtime_error("no condition"); + + const ShaderGraph& graph = GetGraph(); + + const auto& conditionEntry = graph.GetCondition(*m_currentConditionIndex); + return Nz::ShaderBuilder::ConditionalExpression(conditionEntry.name, expressions[0], expressions[1]); +} + +QString ConditionalExpression::caption() const +{ + return "ConditionalExpression (" + QString::fromStdString(m_currentConditionText) + ")"; +} + +QString ConditionalExpression::name() const +{ + return "ConditionalExpression"; +} + +unsigned int ConditionalExpression::nPorts(QtNodes::PortType portType) const +{ + switch (portType) + { + case QtNodes::PortType::In: return 2; + case QtNodes::PortType::Out: return 1; + } + + return 0; +} + +void ConditionalExpression::BuildNodeEdition(QFormLayout* layout) +{ + ShaderNode::BuildNodeEdition(layout); + + QComboBox* conditionSelection = new QComboBox; + for (const auto& conditionEntry : GetGraph().GetConditions()) + conditionSelection->addItem(QString::fromStdString(conditionEntry.name)); + + if (m_currentConditionIndex) + conditionSelection->setCurrentIndex(int(*m_currentConditionIndex)); + else + conditionSelection->setCurrentIndex(-1); + + connect(conditionSelection, qOverload(&QComboBox::currentIndexChanged), [&](int index) + { + if (index >= 0) + m_currentConditionIndex = static_cast(index); + else + m_currentConditionIndex.reset(); + + UpdateConditionText(); + UpdatePreview(); + + Q_EMIT dataUpdated(0); + }); + + layout->addRow(tr("Condition"), conditionSelection); +} + +auto ConditionalExpression::dataType(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const -> QtNodes::NodeDataType +{ + return VecData::Type(); + + assert(portType == QtNodes::PortType::Out); + assert(portIndex == 0); + + if (!m_truePath && !m_falsePath) + return VecData::Type(); + + return (m_truePath) ? m_truePath->type() : m_falsePath->type(); +} + +QString ConditionalExpression::portCaption(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const +{ + switch (portType) + { + case QtNodes::PortType::In: + { + switch (portIndex) + { + case 0: + return "True path"; + + case 1: + return "False path"; + + default: + break; + } + } + + default: + break; + } + + return QString{}; +} + +bool ConditionalExpression::portCaptionVisible(QtNodes::PortType portType, QtNodes::PortIndex /*portIndex*/) const +{ + return portType == QtNodes::PortType::In; +} + +std::shared_ptr ConditionalExpression::outData(QtNodes::PortIndex port) +{ + if (!m_currentConditionIndex) + return nullptr; + + assert(port == 0); + return (GetGraph().IsConditionEnabled(*m_currentConditionIndex)) ? m_truePath : m_falsePath; +} + +void ConditionalExpression::restore(const QJsonObject& data) +{ + m_currentConditionText = data["condition_name"].toString().toStdString(); + OnConditionListUpdate(); + + ShaderNode::restore(data); +} + +QJsonObject ConditionalExpression::save() const +{ + QJsonObject data = ShaderNode::save(); + data["condition_name"] = QString::fromStdString(m_currentConditionText); + + return data; +} + +void ConditionalExpression::setInData(std::shared_ptr value, int index) +{ + assert(index == 0 || index == 1); + + if (index == 0) + m_truePath = std::move(value); + else + m_falsePath = std::move(value); + + UpdatePreview(); +} + +QtNodes::NodeValidationState ConditionalExpression::validationState() const +{ + if (!m_truePath || !m_falsePath) + return QtNodes::NodeValidationState::Error; + + return QtNodes::NodeValidationState::Valid; +} + +QString ConditionalExpression::validationMessage() const +{ + if (!m_truePath || !m_falsePath) + return "Missing input"; + + return QString(); +} + +bool ConditionalExpression::ComputePreview(QPixmap& pixmap) +{ + if (!m_currentConditionIndex) + return false; + + auto input = outData(0); + if (!input || input->type().id != VecData::Type().id) + return false; + + assert(dynamic_cast(input.get()) != nullptr); + const VecData& data = static_cast(*input); + + pixmap = QPixmap::fromImage(data.preview.GenerateImage()); + return true; +} + +void ConditionalExpression::OnConditionListUpdate() +{ + m_currentConditionIndex.reset(); + + std::size_t conditionIndex = 0; + for (const auto& conditionEntry : GetGraph().GetConditions()) + { + if (conditionEntry.name == m_currentConditionText) + { + m_currentConditionIndex = conditionIndex; + break; + } + + conditionIndex++; + } +} + +void ConditionalExpression::UpdateConditionText() +{ + if (m_currentConditionIndex) + { + auto& condition = GetGraph().GetCondition(*m_currentConditionIndex); + m_currentConditionText = condition.name; + } + else + m_currentConditionText.clear(); +} diff --git a/src/ShaderNode/DataModels/ConditionalExpression.hpp b/src/ShaderNode/DataModels/ConditionalExpression.hpp new file mode 100644 index 000000000..a89d6c40c --- /dev/null +++ b/src/ShaderNode/DataModels/ConditionalExpression.hpp @@ -0,0 +1,58 @@ +#pragma once + +#ifndef NAZARA_SHADERNODES_CONDITIONALEXPRESSION_HPP +#define NAZARA_SHADERNODES_CONDITIONALEXPRESSION_HPP + +#include +#include +#include +#include +#include + +class ConditionalExpression : public ShaderNode +{ + public: + ConditionalExpression(ShaderGraph& graph); + ~ConditionalExpression() = default; + + void BuildNodeEdition(QFormLayout* layout) override; + + Nz::ShaderNodes::ExpressionPtr GetExpression(Nz::ShaderNodes::ExpressionPtr* /*expressions*/, std::size_t count) const override; + + QString caption() const override; + QString name() const override; + + unsigned int nPorts(QtNodes::PortType portType) const override; + + QtNodes::NodeDataType dataType(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const override; + + QString portCaption(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const override; + bool portCaptionVisible(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const override; + + std::shared_ptr outData(QtNodes::PortIndex port) override; + + void restore(const QJsonObject& data) override; + QJsonObject save() const override; + + void setInData(std::shared_ptr value, int index) override; + + QtNodes::NodeValidationState validationState() const override; + QString validationMessage() const override; + + private: + bool ComputePreview(QPixmap& pixmap) override; + void OnConditionListUpdate(); + void UpdateConditionText(); + + NazaraSlot(ShaderGraph, OnConditionListUpdate, m_onConditionListUpdateSlot); + NazaraSlot(ShaderGraph, OnConditionUpdate, m_onConditionUpdateSlot); + + std::optional m_currentConditionIndex; + std::shared_ptr m_falsePath; + std::shared_ptr m_truePath; + std::string m_currentConditionText; +}; + +#include + +#endif diff --git a/src/ShaderNode/DataModels/ConditionalExpression.inl b/src/ShaderNode/DataModels/ConditionalExpression.inl new file mode 100644 index 000000000..67cd1a5e7 --- /dev/null +++ b/src/ShaderNode/DataModels/ConditionalExpression.inl @@ -0,0 +1,2 @@ +#include +#include diff --git a/src/ShaderNode/DataModels/OutputValue.cpp b/src/ShaderNode/DataModels/OutputValue.cpp index 9a2cb7030..7ffb4df91 100644 --- a/src/ShaderNode/DataModels/OutputValue.cpp +++ b/src/ShaderNode/DataModels/OutputValue.cpp @@ -238,16 +238,16 @@ void OutputValue::OnOutputListUpdate() { m_currentOutputIndex.reset(); - std::size_t inputIndex = 0; - for (const auto& inputEntry : GetGraph().GetOutputs()) + std::size_t outputIndex = 0; + for (const auto& outputEntry : GetGraph().GetOutputs()) { - if (inputEntry.name == m_currentOutputText) + if (outputEntry.name == m_currentOutputText) { - m_currentOutputIndex = inputIndex; + m_currentOutputIndex = outputIndex; break; } - inputIndex++; + outputIndex++; } } diff --git a/src/ShaderNode/DataModels/ShaderNode.cpp b/src/ShaderNode/DataModels/ShaderNode.cpp index 7ed3bc84c..e596a6a6d 100644 --- a/src/ShaderNode/DataModels/ShaderNode.cpp +++ b/src/ShaderNode/DataModels/ShaderNode.cpp @@ -8,6 +8,7 @@ ShaderNode::ShaderNode(ShaderGraph& graph) : m_previewSize(64, 64), m_pixmapLabel(nullptr), +m_embeddedWidget(nullptr), m_graph(graph), m_enableCustomVariableName(true), m_isPreviewEnabled(false) @@ -86,7 +87,24 @@ void ShaderNode::EnablePreview(bool enable) QWidget* ShaderNode::embeddedWidget() { - return m_pixmapLabel; + if (!m_embeddedWidget) + { + QWidget* embedded = EmbeddedWidget(); + if (embedded) + { + QVBoxLayout* layout = new QVBoxLayout; + layout->addWidget(embedded); + layout->addWidget(m_pixmapLabel); + + m_embeddedWidget = new QWidget; + m_embeddedWidget->setStyleSheet("background-color: rgba(0,0,0,0)"); + m_embeddedWidget->setLayout(layout); + } + else + m_embeddedWidget = m_pixmapLabel; + } + + return m_embeddedWidget; } void ShaderNode::restore(const QJsonObject& data) @@ -121,6 +139,11 @@ bool ShaderNode::ComputePreview(QPixmap& /*pixmap*/) return false; } +QWidget* ShaderNode::EmbeddedWidget() +{ + return nullptr; +} + void ShaderNode::UpdatePreview() { if (!m_pixmap) diff --git a/src/ShaderNode/DataModels/ShaderNode.hpp b/src/ShaderNode/DataModels/ShaderNode.hpp index 6d3dd5cc1..874fce54f 100644 --- a/src/ShaderNode/DataModels/ShaderNode.hpp +++ b/src/ShaderNode/DataModels/ShaderNode.hpp @@ -41,6 +41,7 @@ class ShaderNode : public QtNodes::NodeDataModel protected: inline void DisableCustomVariableName(); inline void EnableCustomVariableName(bool enable = true); + virtual QWidget* EmbeddedWidget(); void UpdatePreview(); private: @@ -48,6 +49,7 @@ class ShaderNode : public QtNodes::NodeDataModel Nz::Vector2i m_previewSize; QLabel* m_pixmapLabel; + QWidget* m_embeddedWidget; std::optional m_pixmap; std::string m_variableName; ShaderGraph& m_graph; diff --git a/src/ShaderNode/ShaderGraph.cpp b/src/ShaderNode/ShaderGraph.cpp index a1799df3b..9a93dcf8f 100644 --- a/src/ShaderNode/ShaderGraph.cpp +++ b/src/ShaderNode/ShaderGraph.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -124,6 +125,17 @@ std::size_t ShaderGraph::AddBuffer(std::string name, BufferType bufferType, std: return index; } +std::size_t ShaderGraph::AddCondition(std::string name) +{ + std::size_t index = m_conditions.size(); + auto& conditionEntry = m_conditions.emplace_back(); + conditionEntry.name = std::move(name); + + OnConditionListUpdate(this); + + return index; +} + std::size_t ShaderGraph::AddInput(std::string name, PrimitiveType type, InputRole role, std::size_t roleIndex, std::size_t locationIndex) { std::size_t index = m_inputs.size(); @@ -185,6 +197,7 @@ void ShaderGraph::Clear() m_flowScene.clear(); m_buffers.clear(); + m_conditions.clear(); m_inputs.clear(); m_structs.clear(); m_outputs.clear(); @@ -197,6 +210,15 @@ void ShaderGraph::Clear() OnTextureListUpdate(this); } +void ShaderGraph::EnableCondition(std::size_t conditionIndex, bool enable) +{ + assert(conditionIndex < m_conditions.size()); + auto& conditionEntry = m_conditions[conditionIndex]; + conditionEntry.enabled = enable; + + OnConditionUpdate(this, conditionIndex); +} + void ShaderGraph::Load(const QJsonObject& data) { Clear(); @@ -218,6 +240,17 @@ void ShaderGraph::Load(const QJsonObject& data) OnBufferListUpdate(this); + QJsonArray conditionArray = data["conditions"].toArray(); + for (const auto& conditionDocRef : conditionArray) + { + QJsonObject conditionDoc = conditionDocRef.toObject(); + + ConditionEntry& condition = m_conditions.emplace_back(); + condition.name = conditionDoc["name"].toString().toStdString(); + } + + OnConditionListUpdate(this); + QJsonArray inputArray = data["inputs"].toArray(); for (const auto& inputDocRef : inputArray) { @@ -312,6 +345,18 @@ QJsonObject ShaderGraph::Save() } sceneJson["buffers"] = bufferArray; + QJsonArray conditionArray; + { + for (const auto& condition : m_conditions) + { + QJsonObject inputDoc; + inputDoc["name"] = QString::fromStdString(condition.name); + + conditionArray.append(inputDoc); + } + } + sceneJson["conditions"] = conditionArray; + QJsonArray inputArray; { for (const auto& input : m_inputs) @@ -436,10 +481,15 @@ Nz::ShaderNodes::StatementPtr ShaderGraph::ToAst() (*it)++; }; + std::vector outputNodes; + m_flowScene.iterateOverNodes([&](QtNodes::Node* node) { if (node->nodeDataModel()->nPorts(QtNodes::PortType::Out) == 0) + { DetectVariables(node); + outputNodes.push_back(node); + } }); QHash variableExpressions; @@ -510,13 +560,8 @@ Nz::ShaderNodes::StatementPtr ShaderGraph::ToAst() return expression; }; - m_flowScene.iterateOverNodes([&](QtNodes::Node* node) - { - if (node->nodeDataModel()->nPorts(QtNodes::PortType::Out) == 0) - { - statements.emplace_back(Nz::ShaderBuilder::ExprStatement(HandleNode(node))); - } - }); + for (QtNodes::Node* node : outputNodes) + statements.emplace_back(Nz::ShaderBuilder::ExprStatement(HandleNode(node))); return Nz::ShaderNodes::StatementBlock::Build(std::move(statements)); } @@ -551,6 +596,15 @@ void ShaderGraph::UpdateBuffer(std::size_t bufferIndex, std::string name, Buffer OnBufferUpdate(this, bufferIndex); } +void ShaderGraph::UpdateCondition(std::size_t conditionIndex, std::string condition) +{ + assert(conditionIndex < m_conditions.size()); + auto& conditionEntry = m_conditions[conditionIndex]; + conditionEntry.name = std::move(condition); + + OnConditionUpdate(this, conditionIndex); +} + void ShaderGraph::UpdateInput(std::size_t inputIndex, std::string name, PrimitiveType type, InputRole role, std::size_t roleIndex, std::size_t locationIndex) { assert(inputIndex < m_inputs.size()); @@ -687,6 +741,7 @@ std::shared_ptr ShaderGraph::BuildRegistry() RegisterShaderNode(*this, registry, "Casts"); RegisterShaderNode(*this, registry, "Casts"); RegisterShaderNode(*this, registry, "Casts"); + RegisterShaderNode(*this, registry, "Shader"); RegisterShaderNode(*this, registry, "Constants"); RegisterShaderNode(*this, registry, "Inputs"); RegisterShaderNode(*this, registry, "Outputs"); diff --git a/src/ShaderNode/ShaderGraph.hpp b/src/ShaderNode/ShaderGraph.hpp index 1df336ce6..44fb4bf77 100644 --- a/src/ShaderNode/ShaderGraph.hpp +++ b/src/ShaderNode/ShaderGraph.hpp @@ -18,6 +18,7 @@ class ShaderGraph { public: struct BufferEntry; + struct ConditionEntry; struct InputEntry; struct OutputEntry; struct StructEntry; @@ -28,6 +29,7 @@ class ShaderGraph ~ShaderGraph(); std::size_t AddBuffer(std::string name, BufferType bufferType, std::size_t structIndex, std::size_t bindingIndex); + std::size_t AddCondition(std::string name); std::size_t AddInput(std::string name, PrimitiveType type, InputRole role, std::size_t roleIndex, std::size_t locationIndex); std::size_t AddOutput(std::string name, PrimitiveType type, std::size_t locationIndex); std::size_t AddStruct(std::string name, std::vector members); @@ -35,9 +37,14 @@ class ShaderGraph void Clear(); + void EnableCondition(std::size_t conditionIndex, bool enable); + inline const BufferEntry& GetBuffer(std::size_t bufferIndex) const; inline std::size_t GetBufferCount() const; inline const std::vector& GetBuffers() const; + inline const ConditionEntry& GetCondition(std::size_t conditionIndex) const; + inline std::size_t GetConditionCount() const; + inline const std::vector& GetConditions() const; inline const InputEntry& GetInput(std::size_t bufferIndex) const; inline std::size_t GetInputCount() const; inline const std::vector& GetInputs() const; @@ -54,6 +61,8 @@ class ShaderGraph inline const std::vector& GetTextures() const; inline ShaderType GetType() const; + inline bool IsConditionEnabled(std::size_t conditionIndex) const; + void Load(const QJsonObject& data); QJsonObject Save(); @@ -61,6 +70,7 @@ class ShaderGraph Nz::ShaderExpressionType ToShaderExpressionType(const std::variant& type) const; void UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t bindingIndex); + void UpdateCondition(std::size_t conditionIndex, std::string condition); void UpdateInput(std::size_t inputIndex, std::string name, PrimitiveType type, InputRole role, std::size_t roleIndex, std::size_t locationIndex); void UpdateOutput(std::size_t outputIndex, std::string name, PrimitiveType type, std::size_t locationIndex); void UpdateStruct(std::size_t structIndex, std::string name, std::vector members); @@ -76,6 +86,12 @@ class ShaderGraph BufferType type; }; + struct ConditionEntry + { + std::string name; + bool enabled = false; + }; + struct InputEntry { std::size_t locationIndex; @@ -113,7 +129,9 @@ class ShaderGraph }; NazaraSignal(OnBufferListUpdate, ShaderGraph*); - NazaraSignal(OnBufferUpdate, ShaderGraph*, std::size_t /*outputIndex*/); + NazaraSignal(OnBufferUpdate, ShaderGraph*, std::size_t /*bufferIndex*/); + NazaraSignal(OnConditionListUpdate, ShaderGraph*); + NazaraSignal(OnConditionUpdate, ShaderGraph*, std::size_t /*conditionIndex*/); NazaraSignal(OnInputListUpdate, ShaderGraph*); NazaraSignal(OnInputUpdate, ShaderGraph*, std::size_t /*inputIndex*/); NazaraSignal(OnOutputListUpdate, ShaderGraph*); @@ -136,6 +154,7 @@ class ShaderGraph QtNodes::FlowScene m_flowScene; std::vector m_buffers; + std::vector m_conditions; std::vector m_inputs; std::vector m_outputs; std::vector m_structs; diff --git a/src/ShaderNode/ShaderGraph.inl b/src/ShaderNode/ShaderGraph.inl index afa867dcd..11ba6f5dc 100644 --- a/src/ShaderNode/ShaderGraph.inl +++ b/src/ShaderNode/ShaderGraph.inl @@ -16,6 +16,22 @@ inline auto ShaderGraph::GetBuffers() const -> const std::vector& return m_buffers; } +inline auto ShaderGraph::GetCondition(std::size_t conditionIndex) const -> const ConditionEntry& +{ + assert(conditionIndex < m_conditions.size()); + return m_conditions[conditionIndex]; +} + +inline std::size_t ShaderGraph::GetConditionCount() const +{ + return m_conditions.size(); +} + +inline auto ShaderGraph::GetConditions() const -> const std::vector& +{ + return m_conditions; +} + inline auto ShaderGraph::GetInput(std::size_t inputIndex) const -> const InputEntry& { assert(inputIndex < m_inputs.size()); @@ -95,3 +111,9 @@ inline ShaderType ShaderGraph::GetType() const return m_type; } +inline bool ShaderGraph::IsConditionEnabled(std::size_t conditionIndex) const +{ + assert(conditionIndex < m_conditions.size()); + return m_conditions[conditionIndex].enabled; +} + diff --git a/src/ShaderNode/Widgets/ConditionEditDialog.cpp b/src/ShaderNode/Widgets/ConditionEditDialog.cpp new file mode 100644 index 000000000..4d04932a6 --- /dev/null +++ b/src/ShaderNode/Widgets/ConditionEditDialog.cpp @@ -0,0 +1,55 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +ConditionEditDialog::ConditionEditDialog(QWidget* parent) : +QDialog(parent) +{ + setWindowTitle(tr("Condition edit dialog")); + setWindowFlags(windowFlags() & ~Qt::WindowContextHelpButtonHint); + + m_conditionName = new QLineEdit; + + QFormLayout* formLayout = new QFormLayout; + formLayout->addRow(tr("Name"), m_conditionName); + + QDialogButtonBox* buttonBox = new QDialogButtonBox(QDialogButtonBox::Ok | QDialogButtonBox::Cancel); + connect(buttonBox, &QDialogButtonBox::accepted, this, &ConditionEditDialog::OnAccept); + connect(buttonBox, &QDialogButtonBox::rejected, this, &QDialog::reject); + + QVBoxLayout* verticalLayout = new QVBoxLayout; + verticalLayout->addLayout(formLayout); + verticalLayout->addWidget(buttonBox); + + setLayout(verticalLayout); +} + +ConditionEditDialog::ConditionEditDialog(const ConditionInfo& condition, QWidget* parent) : +ConditionEditDialog(parent) +{ + m_conditionName->setText(QString::fromStdString(condition.name)); +} + +ConditionInfo ConditionEditDialog::GetConditionInfo() const +{ + ConditionInfo inputInfo; + inputInfo.name = m_conditionName->text().toStdString(); + + return inputInfo; +} + +void ConditionEditDialog::OnAccept() +{ + if (m_conditionName->text().isEmpty()) + { + QMessageBox::critical(this, tr("Empty name"), tr("Condition name must be set"), QMessageBox::Ok); + return; + } + + accept(); +} diff --git a/src/ShaderNode/Widgets/ConditionEditDialog.hpp b/src/ShaderNode/Widgets/ConditionEditDialog.hpp new file mode 100644 index 000000000..22cebb2f8 --- /dev/null +++ b/src/ShaderNode/Widgets/ConditionEditDialog.hpp @@ -0,0 +1,35 @@ +#pragma once + +#ifndef NAZARA_SHADERNODES_CONDITIONEDITDIALOG_HPP +#define NAZARA_SHADERNODES_CONDITIONEDITDIALOG_HPP + +#include +#include + +class QComboBox; +class QLineEdit; +class QSpinBox; + +struct ConditionInfo +{ + std::string name; +}; + +class ConditionEditDialog : public QDialog +{ + public: + ConditionEditDialog(QWidget* parent = nullptr); + ConditionEditDialog(const ConditionInfo& input, QWidget* parent = nullptr); + ~ConditionEditDialog() = default; + + ConditionInfo GetConditionInfo() const; + + private: + void OnAccept(); + + QLineEdit* m_conditionName; +}; + +#include + +#endif diff --git a/src/ShaderNode/Widgets/ConditionEditDialog.inl b/src/ShaderNode/Widgets/ConditionEditDialog.inl new file mode 100644 index 000000000..157a5289b --- /dev/null +++ b/src/ShaderNode/Widgets/ConditionEditDialog.inl @@ -0,0 +1 @@ +#include diff --git a/src/ShaderNode/Widgets/ConditionEditor.cpp b/src/ShaderNode/Widgets/ConditionEditor.cpp new file mode 100644 index 000000000..2b8c40ab8 --- /dev/null +++ b/src/ShaderNode/Widgets/ConditionEditor.cpp @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +ConditionEditor::ConditionEditor(ShaderGraph& graph) : +m_shaderGraph(graph) +{ + QTableView* tableView = new QTableView; + m_model = new QStandardItemModel(0, 2, tableView); + tableView->setModel(m_model); + + m_model->setHorizontalHeaderLabels({ tr("Condition"), tr("Enabled") }); + + connect(tableView, &QTableView::doubleClicked, [this](const QModelIndex& index) + { + if (index.column() == 0) + OnEditCondition(index.row()); + }); + + connect(m_model, &QStandardItemModel::itemChanged, [this](QStandardItem* item) + { + if (item->column() == 1) + { + std::size_t conditionIndex = static_cast(item->row()); + bool value = item->checkState() == Qt::Checked; + m_shaderGraph.EnableCondition(conditionIndex, value); + } + }); + + + QPushButton* addStructButton = new QPushButton(tr("Add condition...")); + connect(addStructButton, &QPushButton::released, this, &ConditionEditor::OnAddCondition); + + m_layout = new QVBoxLayout; + m_layout->addWidget(tableView); + m_layout->addWidget(addStructButton); + + setLayout(m_layout); + + m_onConditionListUpdateSlot.Connect(m_shaderGraph.OnConditionListUpdate, this, &ConditionEditor::OnConditionListUpdate); + m_onConditionUpdateSlot.Connect(m_shaderGraph.OnConditionUpdate, this, &ConditionEditor::OnConditionUpdate); + + RefreshConditions(); +} + +void ConditionEditor::OnAddCondition() +{ + ConditionEditDialog* dialog = new ConditionEditDialog(this); + dialog->setAttribute(Qt::WA_DeleteOnClose, true); + connect(dialog, &QDialog::accepted, [this, dialog] + { + ConditionInfo conditionInfo = dialog->GetConditionInfo(); + m_shaderGraph.AddCondition(std::move(conditionInfo.name)); + }); + + dialog->open(); +} + +void ConditionEditor::OnEditCondition(int conditionIndex) +{ + const auto& conditionInfo = m_shaderGraph.GetCondition(conditionIndex); + + ConditionInfo info; + info.name = conditionInfo.name; + + ConditionEditDialog* dialog = new ConditionEditDialog(info, this); + dialog->setAttribute(Qt::WA_DeleteOnClose, true); + connect(dialog, &QDialog::accepted, [this, dialog, conditionIndex] + { + ConditionInfo conditionInfo = dialog->GetConditionInfo(); + + m_shaderGraph.UpdateCondition(conditionIndex, std::move(conditionInfo.name)); + }); + + dialog->open(); +} + +void ConditionEditor::OnConditionListUpdate(ShaderGraph* /*graph*/) +{ + RefreshConditions(); +} + +void ConditionEditor::OnConditionUpdate(ShaderGraph* /*graph*/, std::size_t conditionIndex) +{ + const auto& conditionEntry = m_shaderGraph.GetCondition(conditionIndex); + + int row = int(conditionIndex); + m_model->item(row, 0)->setText(QString::fromStdString(conditionEntry.name)); + m_model->item(row, 1)->setCheckState((conditionEntry.enabled) ? Qt::CheckState::Checked : Qt::CheckState::Unchecked); +} + +void ConditionEditor::RefreshConditions() +{ + m_model->setRowCount(int(m_shaderGraph.GetConditionCount())); + + int rowIndex = 0; + for (const auto& conditionEntry : m_shaderGraph.GetConditions()) + { + QStandardItem* label = new QStandardItem(1); + label->setEditable(false); + label->setText(QString::fromStdString(conditionEntry.name)); + + m_model->setItem(rowIndex, 0, label); + + QStandardItem* checkbox = new QStandardItem(1); + checkbox->setCheckable(true); + checkbox->setCheckState((conditionEntry.enabled) ? Qt::CheckState::Checked : Qt::CheckState::Unchecked); + + m_model->setItem(rowIndex, 1, checkbox); + } +} diff --git a/src/ShaderNode/Widgets/ConditionEditor.hpp b/src/ShaderNode/Widgets/ConditionEditor.hpp new file mode 100644 index 000000000..b88b035fe --- /dev/null +++ b/src/ShaderNode/Widgets/ConditionEditor.hpp @@ -0,0 +1,36 @@ +#pragma once + +#ifndef NAZARA_SHADERNODES_CONDITIONEDITOR_HPP +#define NAZARA_SHADERNODES_CONDITIONEDITOR_HPP + +#include +#include +#include + +class QStandardItemModel; +class QVBoxLayout; + +class ConditionEditor : public QWidget +{ + public: + ConditionEditor(ShaderGraph& graph); + ~ConditionEditor() = default; + + private: + void OnAddCondition(); + void OnConditionListUpdate(ShaderGraph* graph); + void OnConditionUpdate(ShaderGraph* graph, std::size_t conditionIndex); + void OnEditCondition(int inputIndex); + void RefreshConditions(); + + NazaraSlot(ShaderGraph, OnStructListUpdate, m_onConditionListUpdateSlot); + NazaraSlot(ShaderGraph, OnStructUpdate, m_onConditionUpdateSlot); + + ShaderGraph& m_shaderGraph; + QStandardItemModel* m_model; + QVBoxLayout* m_layout; +}; + +#include + +#endif diff --git a/src/ShaderNode/Widgets/ConditionEditor.inl b/src/ShaderNode/Widgets/ConditionEditor.inl new file mode 100644 index 000000000..7367d2757 --- /dev/null +++ b/src/ShaderNode/Widgets/ConditionEditor.inl @@ -0,0 +1 @@ +#include diff --git a/src/ShaderNode/Widgets/MainWindow.cpp b/src/ShaderNode/Widgets/MainWindow.cpp index b88bef25f..a9bfcc59f 100644 --- a/src/ShaderNode/Widgets/MainWindow.cpp +++ b/src/ShaderNode/Widgets/MainWindow.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -84,6 +85,15 @@ m_shaderGraph(graph) addDockWidget(Qt::RightDockWidgetArea, structDock); + // Condition editor + ConditionEditor* conditionEditor = new ConditionEditor(m_shaderGraph); + + QDockWidget* conditionDock = new QDockWidget(tr("Conditions")); + conditionDock->setFeatures(QDockWidget::DockWidgetFloatable | QDockWidget::DockWidgetMovable); + conditionDock->setWidget(conditionEditor); + + addDockWidget(Qt::RightDockWidgetArea, conditionDock); + m_onSelectedNodeUpdate.Connect(m_shaderGraph.OnSelectedNodeUpdate, [&](ShaderGraph*, ShaderNode* node) { if (node) @@ -99,6 +109,21 @@ m_shaderGraph(graph) BuildMenu(); + + m_codeOutput = new QTextEdit; + m_codeOutput->setReadOnly(true); + m_codeOutput->setWindowTitle("GLSL Output"); + + m_onConditionUpdate.Connect(m_shaderGraph.OnConditionUpdate, [&](ShaderGraph*, std::size_t conditionIndex) + { + if (m_codeOutput->isVisible()) + OnGenerateGLSL(); + }); +} + +MainWindow::~MainWindow() +{ + delete m_codeOutput; } void MainWindow::BuildMenu() @@ -109,6 +134,7 @@ void MainWindow::BuildMenu() { QAction* loadShader = file->addAction(tr("Load...")); QObject::connect(loadShader, &QAction::triggered, this, &MainWindow::OnLoad); + QAction* saveShader = file->addAction(tr("Save...")); QObject::connect(saveShader, &QAction::triggered, this, &MainWindow::OnSave); } @@ -117,6 +143,7 @@ void MainWindow::BuildMenu() { QAction* settings = shader->addAction(tr("Settings...")); QObject::connect(settings, &QAction::triggered, this, &MainWindow::OnUpdateInfo); + QAction* compileShader = shader->addAction(tr("Compile...")); QObject::connect(compileShader, &QAction::triggered, this, &MainWindow::OnCompile); } @@ -155,16 +182,20 @@ void MainWindow::OnGenerateGLSL() try { Nz::GlslWriter writer; - std::string glsl = writer.Generate(ToShader()); + + Nz::GlslWriter::States states; + for (const auto& condition : m_shaderGraph.GetConditions()) + { + if (condition.enabled) + states.enabledConditions.insert(condition.name); + } + + std::string glsl = writer.Generate(ToShader(), states); std::cout << glsl << std::endl; - QTextEdit* output = new QTextEdit; - output->setReadOnly(true); - output->setText(QString::fromStdString(glsl)); - output->setAttribute(Qt::WA_DeleteOnClose, true); - output->setWindowTitle("GLSL Output"); - output->show(); + m_codeOutput->setText(QString::fromStdString(glsl)); + m_codeOutput->show(); } catch (const std::exception& e) { @@ -238,6 +269,9 @@ Nz::ShaderAst MainWindow::ToShader() Nz::ShaderNodes::StatementPtr shaderAst = m_shaderGraph.ToAst(); Nz::ShaderAst shader(ShaderGraph::ToShaderStageType(m_shaderGraph.GetType())); //< FIXME + for (const auto& condition : m_shaderGraph.GetConditions()) + shader.AddCondition(condition.name); + for (const auto& input : m_shaderGraph.GetInputs()) shader.AddInput(input.name, m_shaderGraph.ToShaderExpressionType(input.type), input.locationIndex); diff --git a/src/ShaderNode/Widgets/MainWindow.hpp b/src/ShaderNode/Widgets/MainWindow.hpp index 76615bb98..10d925190 100644 --- a/src/ShaderNode/Widgets/MainWindow.hpp +++ b/src/ShaderNode/Widgets/MainWindow.hpp @@ -8,6 +8,7 @@ #include class NodeEditor; +class QTextEdit; namespace Nz { @@ -18,7 +19,7 @@ class MainWindow : public QMainWindow { public: MainWindow(ShaderGraph& graph); - ~MainWindow() = default; + ~MainWindow(); private: void BuildMenu(); @@ -29,10 +30,12 @@ class MainWindow : public QMainWindow void OnUpdateInfo(); Nz::ShaderAst ToShader(); + NazaraSlot(ShaderGraph, OnConditionUpdate, m_onConditionUpdate); NazaraSlot(ShaderGraph, OnSelectedNodeUpdate, m_onSelectedNodeUpdate); NodeEditor* m_nodeEditor; ShaderGraph& m_shaderGraph; + QTextEdit* m_codeOutput; }; #include