Improve code

This commit is contained in:
Jérôme Leclercq
2020-07-19 21:08:43 +02:00
parent e342c88e64
commit 3c1c61fb5e
3 changed files with 53 additions and 31 deletions

View File

@@ -212,39 +212,12 @@ Nz::ShaderAst MainWindow::ToShader()
{
Nz::ShaderNodes::StatementPtr shaderAst = m_shaderGraph.ToAst();
//TODO: Put in another function
auto GetExpressionFromInOut = [&](PrimitiveType type)
{
switch (type)
{
case PrimitiveType::Bool: return Nz::ShaderNodes::BasicType::Boolean;
case PrimitiveType::Float1: return Nz::ShaderNodes::BasicType::Float1;
case PrimitiveType::Float2: return Nz::ShaderNodes::BasicType::Float2;
case PrimitiveType::Float3: return Nz::ShaderNodes::BasicType::Float3;
case PrimitiveType::Float4: return Nz::ShaderNodes::BasicType::Float4;
}
assert(false);
throw std::runtime_error("Unhandled input type");
};
auto GetExpressionFromTexture = [&](TextureType type)
{
switch (type)
{
case TextureType::Sampler2D: return Nz::ShaderNodes::BasicType::Sampler2D;
}
assert(false);
throw std::runtime_error("Unhandled texture type");
};
Nz::ShaderAst shader;
for (const auto& input : m_shaderGraph.GetInputs())
shader.AddInput(input.name, GetExpressionFromInOut(input.type), input.locationIndex);
shader.AddInput(input.name, m_shaderGraph.ToShaderExpressionType(input.type), input.locationIndex);
for (const auto& output : m_shaderGraph.GetOutputs())
shader.AddOutput(output.name, GetExpressionFromInOut(output.type), output.locationIndex);
shader.AddOutput(output.name, m_shaderGraph.ToShaderExpressionType(output.type), output.locationIndex);
for (const auto& buffer : m_shaderGraph.GetBuffers())
{
@@ -253,7 +226,7 @@ Nz::ShaderAst MainWindow::ToShader()
}
for (const auto& uniform : m_shaderGraph.GetTextures())
shader.AddUniform(uniform.name, GetExpressionFromTexture(uniform.type), uniform.bindingIndex, {});
shader.AddUniform(uniform.name, m_shaderGraph.ToShaderExpressionType(uniform.type), uniform.bindingIndex, {});
for (const auto& s : m_shaderGraph.GetStructs())
{
@@ -267,7 +240,7 @@ Nz::ShaderAst MainWindow::ToShader()
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, PrimitiveType>)
member.type = GetExpressionFromInOut(arg);
member.type = m_shaderGraph.ToShaderExpressionType(arg);
else if constexpr (std::is_same_v<T, std::size_t>)
member.type = m_shaderGraph.GetStruct(arg).name;
else