diff --git a/src/ShaderNode/ShaderGraph.cpp b/src/ShaderNode/ShaderGraph.cpp index 3648fed82..2d5ba933d 100644 --- a/src/ShaderNode/ShaderGraph.cpp +++ b/src/ShaderNode/ShaderGraph.cpp @@ -498,6 +498,24 @@ Nz::ShaderNodes::StatementPtr ShaderGraph::ToAst() return Nz::ShaderNodes::StatementBlock::Build(std::move(statements)); } +Nz::ShaderExpressionType ShaderGraph::ToShaderExpressionType(const std::variant& type) const +{ + return std::visit([&](auto&& arg) -> Nz::ShaderExpressionType + { + using T = std::decay_t; + if constexpr (std::is_same_v) + return ToShaderExpressionType(arg); + else if constexpr (std::is_same_v) + { + assert(arg < m_structs.size()); + const auto& s = m_structs[arg]; + return s.name; + } + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + }, type); +}; + void ShaderGraph::UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t bindingIndex) { assert(bufferIndex < m_buffers.size()); @@ -565,9 +583,36 @@ void ShaderGraph::UpdateTexturePreview(std::size_t textureIndex, QImage preview) OnTexturePreviewUpdate(this, textureIndex); } +Nz::ShaderExpressionType ShaderGraph::ToShaderExpressionType(PrimitiveType type) +{ + switch (type) + { + case PrimitiveType::Bool: return Nz::ShaderNodes::BasicType::Boolean; + case PrimitiveType::Float1: return Nz::ShaderNodes::BasicType::Float1; + case PrimitiveType::Float2: return Nz::ShaderNodes::BasicType::Float2; + case PrimitiveType::Float3: return Nz::ShaderNodes::BasicType::Float3; + case PrimitiveType::Float4: return Nz::ShaderNodes::BasicType::Float4; + } + + assert(false); + throw std::runtime_error("Unhandled primitive type"); +} + +Nz::ShaderExpressionType ShaderGraph::ToShaderExpressionType(TextureType type) +{ + switch (type) + { + case TextureType::Sampler2D: return Nz::ShaderNodes::BasicType::Sampler2D; + } + + assert(false); + throw std::runtime_error("Unhandled texture type"); +} + std::shared_ptr ShaderGraph::BuildRegistry() { auto registry = std::make_shared(); + RegisterShaderNode(*this, registry, "Inputs"); RegisterShaderNode(*this, registry, "Casts"); RegisterShaderNode(*this, registry, "Casts"); RegisterShaderNode(*this, registry, "Casts"); diff --git a/src/ShaderNode/ShaderGraph.hpp b/src/ShaderNode/ShaderGraph.hpp index d0ff63316..0cab330d0 100644 --- a/src/ShaderNode/ShaderGraph.hpp +++ b/src/ShaderNode/ShaderGraph.hpp @@ -56,6 +56,7 @@ class ShaderGraph QJsonObject Save(); Nz::ShaderNodes::StatementPtr ToAst(); + Nz::ShaderExpressionType ToShaderExpressionType(const std::variant& type) const; void UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t bindingIndex); void UpdateInput(std::size_t inputIndex, std::string name, PrimitiveType type, InputRole role, std::size_t roleIndex, std::size_t locationIndex); @@ -121,6 +122,9 @@ class ShaderGraph NazaraSignal(OnTexturePreviewUpdate, ShaderGraph*, std::size_t /*textureIndex*/); NazaraSignal(OnTextureUpdate, ShaderGraph*, std::size_t /*textureIndex*/); + static Nz::ShaderExpressionType ToShaderExpressionType(PrimitiveType type); + static Nz::ShaderExpressionType ToShaderExpressionType(TextureType type); + private: std::shared_ptr BuildRegistry(); diff --git a/src/ShaderNode/Widgets/MainWindow.cpp b/src/ShaderNode/Widgets/MainWindow.cpp index 813d1b5ab..7c7101047 100644 --- a/src/ShaderNode/Widgets/MainWindow.cpp +++ b/src/ShaderNode/Widgets/MainWindow.cpp @@ -212,39 +212,12 @@ Nz::ShaderAst MainWindow::ToShader() { Nz::ShaderNodes::StatementPtr shaderAst = m_shaderGraph.ToAst(); - //TODO: Put in another function - auto GetExpressionFromInOut = [&](PrimitiveType type) - { - switch (type) - { - case PrimitiveType::Bool: return Nz::ShaderNodes::BasicType::Boolean; - case PrimitiveType::Float1: return Nz::ShaderNodes::BasicType::Float1; - case PrimitiveType::Float2: return Nz::ShaderNodes::BasicType::Float2; - case PrimitiveType::Float3: return Nz::ShaderNodes::BasicType::Float3; - case PrimitiveType::Float4: return Nz::ShaderNodes::BasicType::Float4; - } - - assert(false); - throw std::runtime_error("Unhandled input type"); - }; - - auto GetExpressionFromTexture = [&](TextureType type) - { - switch (type) - { - case TextureType::Sampler2D: return Nz::ShaderNodes::BasicType::Sampler2D; - } - - assert(false); - throw std::runtime_error("Unhandled texture type"); - }; - Nz::ShaderAst shader; for (const auto& input : m_shaderGraph.GetInputs()) - shader.AddInput(input.name, GetExpressionFromInOut(input.type), input.locationIndex); + shader.AddInput(input.name, m_shaderGraph.ToShaderExpressionType(input.type), input.locationIndex); for (const auto& output : m_shaderGraph.GetOutputs()) - shader.AddOutput(output.name, GetExpressionFromInOut(output.type), output.locationIndex); + shader.AddOutput(output.name, m_shaderGraph.ToShaderExpressionType(output.type), output.locationIndex); for (const auto& buffer : m_shaderGraph.GetBuffers()) { @@ -253,7 +226,7 @@ Nz::ShaderAst MainWindow::ToShader() } for (const auto& uniform : m_shaderGraph.GetTextures()) - shader.AddUniform(uniform.name, GetExpressionFromTexture(uniform.type), uniform.bindingIndex, {}); + shader.AddUniform(uniform.name, m_shaderGraph.ToShaderExpressionType(uniform.type), uniform.bindingIndex, {}); for (const auto& s : m_shaderGraph.GetStructs()) { @@ -267,7 +240,7 @@ Nz::ShaderAst MainWindow::ToShader() { using T = std::decay_t; if constexpr (std::is_same_v) - member.type = GetExpressionFromInOut(arg); + member.type = m_shaderGraph.ToShaderExpressionType(arg); else if constexpr (std::is_same_v) member.type = m_shaderGraph.GetStruct(arg).name; else