Improve code

This commit is contained in:
Jérôme Leclercq 2020-07-19 21:08:43 +02:00
parent e342c88e64
commit 3c1c61fb5e
3 changed files with 53 additions and 31 deletions

View File

@ -498,6 +498,24 @@ Nz::ShaderNodes::StatementPtr ShaderGraph::ToAst()
return Nz::ShaderNodes::StatementBlock::Build(std::move(statements));
}
Nz::ShaderExpressionType ShaderGraph::ToShaderExpressionType(const std::variant<PrimitiveType, std::size_t>& type) const
{
return std::visit([&](auto&& arg) -> Nz::ShaderExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, PrimitiveType>)
return ToShaderExpressionType(arg);
else if constexpr (std::is_same_v<T, std::size_t>)
{
assert(arg < m_structs.size());
const auto& s = m_structs[arg];
return s.name;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
};
void ShaderGraph::UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t bindingIndex)
{
assert(bufferIndex < m_buffers.size());
@ -565,9 +583,36 @@ void ShaderGraph::UpdateTexturePreview(std::size_t textureIndex, QImage preview)
OnTexturePreviewUpdate(this, textureIndex);
}
Nz::ShaderExpressionType ShaderGraph::ToShaderExpressionType(PrimitiveType type)
{
switch (type)
{
case PrimitiveType::Bool: return Nz::ShaderNodes::BasicType::Boolean;
case PrimitiveType::Float1: return Nz::ShaderNodes::BasicType::Float1;
case PrimitiveType::Float2: return Nz::ShaderNodes::BasicType::Float2;
case PrimitiveType::Float3: return Nz::ShaderNodes::BasicType::Float3;
case PrimitiveType::Float4: return Nz::ShaderNodes::BasicType::Float4;
}
assert(false);
throw std::runtime_error("Unhandled primitive type");
}
Nz::ShaderExpressionType ShaderGraph::ToShaderExpressionType(TextureType type)
{
switch (type)
{
case TextureType::Sampler2D: return Nz::ShaderNodes::BasicType::Sampler2D;
}
assert(false);
throw std::runtime_error("Unhandled texture type");
}
std::shared_ptr<QtNodes::DataModelRegistry> ShaderGraph::BuildRegistry()
{
auto registry = std::make_shared<QtNodes::DataModelRegistry>();
RegisterShaderNode<BufferField>(*this, registry, "Inputs");
RegisterShaderNode<CastToVec2>(*this, registry, "Casts");
RegisterShaderNode<CastToVec3>(*this, registry, "Casts");
RegisterShaderNode<CastToVec4>(*this, registry, "Casts");

View File

@ -56,6 +56,7 @@ class ShaderGraph
QJsonObject Save();
Nz::ShaderNodes::StatementPtr ToAst();
Nz::ShaderExpressionType ToShaderExpressionType(const std::variant<PrimitiveType, std::size_t>& type) const;
void UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t bindingIndex);
void UpdateInput(std::size_t inputIndex, std::string name, PrimitiveType type, InputRole role, std::size_t roleIndex, std::size_t locationIndex);
@ -121,6 +122,9 @@ class ShaderGraph
NazaraSignal(OnTexturePreviewUpdate, ShaderGraph*, std::size_t /*textureIndex*/);
NazaraSignal(OnTextureUpdate, ShaderGraph*, std::size_t /*textureIndex*/);
static Nz::ShaderExpressionType ToShaderExpressionType(PrimitiveType type);
static Nz::ShaderExpressionType ToShaderExpressionType(TextureType type);
private:
std::shared_ptr<QtNodes::DataModelRegistry> BuildRegistry();

View File

@ -212,39 +212,12 @@ Nz::ShaderAst MainWindow::ToShader()
{
Nz::ShaderNodes::StatementPtr shaderAst = m_shaderGraph.ToAst();
//TODO: Put in another function
auto GetExpressionFromInOut = [&](PrimitiveType type)
{
switch (type)
{
case PrimitiveType::Bool: return Nz::ShaderNodes::BasicType::Boolean;
case PrimitiveType::Float1: return Nz::ShaderNodes::BasicType::Float1;
case PrimitiveType::Float2: return Nz::ShaderNodes::BasicType::Float2;
case PrimitiveType::Float3: return Nz::ShaderNodes::BasicType::Float3;
case PrimitiveType::Float4: return Nz::ShaderNodes::BasicType::Float4;
}
assert(false);
throw std::runtime_error("Unhandled input type");
};
auto GetExpressionFromTexture = [&](TextureType type)
{
switch (type)
{
case TextureType::Sampler2D: return Nz::ShaderNodes::BasicType::Sampler2D;
}
assert(false);
throw std::runtime_error("Unhandled texture type");
};
Nz::ShaderAst shader;
for (const auto& input : m_shaderGraph.GetInputs())
shader.AddInput(input.name, GetExpressionFromInOut(input.type), input.locationIndex);
shader.AddInput(input.name, m_shaderGraph.ToShaderExpressionType(input.type), input.locationIndex);
for (const auto& output : m_shaderGraph.GetOutputs())
shader.AddOutput(output.name, GetExpressionFromInOut(output.type), output.locationIndex);
shader.AddOutput(output.name, m_shaderGraph.ToShaderExpressionType(output.type), output.locationIndex);
for (const auto& buffer : m_shaderGraph.GetBuffers())
{
@ -253,7 +226,7 @@ Nz::ShaderAst MainWindow::ToShader()
}
for (const auto& uniform : m_shaderGraph.GetTextures())
shader.AddUniform(uniform.name, GetExpressionFromTexture(uniform.type), uniform.bindingIndex, {});
shader.AddUniform(uniform.name, m_shaderGraph.ToShaderExpressionType(uniform.type), uniform.bindingIndex, {});
for (const auto& s : m_shaderGraph.GetStructs())
{
@ -267,7 +240,7 @@ Nz::ShaderAst MainWindow::ToShader()
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, PrimitiveType>)
member.type = GetExpressionFromInOut(arg);
member.type = m_shaderGraph.ToShaderExpressionType(arg);
else if constexpr (std::is_same_v<T, std::size_t>)
member.type = m_shaderGraph.GetStruct(arg).name;
else