Fix compilation

This commit is contained in:
Jérôme Leclercq
2022-03-04 18:16:12 +01:00
parent 0c3607579e
commit 1919bd3302
22 changed files with 322 additions and 136 deletions

View File

@@ -735,6 +735,24 @@ namespace Nz::ShaderAst
#undef EnableOptimisation
}
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule)
{
ModulePtr clone = std::make_shared<Module>();
clone->metadata = shaderModule.metadata;
clone->rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode));
return clone;
}
ModulePtr AstConstantPropagationVisitor::Process(const Module& shaderModule, const Options& options)
{
ModulePtr clone = std::make_shared<Module>();
clone->metadata = shaderModule.metadata;
clone->rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, options));
return clone;
}
ExpressionPtr AstConstantPropagationVisitor::Clone(BinaryExpression& node)
{
auto lhs = CloneExpression(node.left);

View File

@@ -5,17 +5,33 @@
#include <Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
#include <unordered_map>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
namespace
{
template<typename T, typename U>
std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
{
return std::unique_ptr<T>(static_cast<T*>(ptr.release()));
}
}
struct EliminateUnusedPassVisitor::Context
{
const DependencyCheckerVisitor::UsageSet& usageSet;
};
ModulePtr EliminateUnusedPassVisitor::Process(const Module& shaderModule, const DependencyCheckerVisitor::UsageSet& usageSet)
{
ModulePtr clone = std::make_shared<Module>();
clone->metadata = shaderModule.metadata;
clone->rootNode = static_unique_pointer_cast<MultiStatement>(Process(*shaderModule.rootNode, usageSet));
return clone;
}
StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement, const DependencyCheckerVisitor::UsageSet& usageSet)
{
Context context{

View File

@@ -112,7 +112,7 @@ namespace Nz::ShaderAst
std::vector<StatementPtr>* currentStatementList = nullptr;
};
ModulePtr SanitizeVisitor::Sanitize(Module& module, const Options& options, std::string* error)
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{
ModulePtr clone = std::make_shared<Module>();
clone->shaderLangVersion = module.shaderLangVersion;

View File

@@ -457,13 +457,18 @@ QJsonObject ShaderGraph::Save()
return sceneJson;
}
Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const
Nz::ShaderAst::ModulePtr ShaderGraph::ToModule() const
{
std::vector<Nz::ShaderAst::StatementPtr> statements;
Nz::ShaderAst::ModulePtr shaderModule = std::make_shared<Nz::ShaderAst::Module>();
std::shared_ptr<Nz::ShaderAst::Module::Metadata> moduleMetada = std::make_shared<Nz::ShaderAst::Module::Metadata>();
moduleMetada->shaderLangVersion = 100;
shaderModule->metadata = std::move(moduleMetada);
// Declare all options
for (const auto& option : m_options)
statements.push_back(Nz::ShaderBuilder::DeclareOption(option.name, Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Boolean }));
shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareOption(option.name, Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::PrimitiveType::Boolean }));
// Declare all structures
for (const auto& structInfo : m_structs)
@@ -479,7 +484,7 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const
structMember.type = ToShaderExpressionType(memberInfo.type);
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc)));
shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc), false));
}
// External block
@@ -509,7 +514,7 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const
}
if (!external->externalVars.empty())
statements.push_back(std::move(external));
shaderModule->rootNode->statements.push_back(std::move(external));
// Inputs / outputs
if (!m_inputs.empty())
@@ -525,7 +530,7 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const
structMember.locationIndex = input.locationIndex;
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc)));
shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc), false));
}
if (!m_outputs.empty())
@@ -549,13 +554,13 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const
position.type = Nz::ShaderAst::ExpressionType{ Nz::ShaderAst::VectorType{ 4, Nz::ShaderAst::PrimitiveType::Float32 } };
}
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc)));
shaderModule->rootNode->statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(structDesc), false));
}
// Functions
statements.push_back(ToFunction());
shaderModule->rootNode->statements.push_back(ToFunction());
return Nz::ShaderBuilder::MultiStatement(std::move(statements));
return shaderModule;
}
Nz::ShaderAst::ExpressionType ShaderGraph::ToShaderExpressionType(const std::variant<PrimitiveType, std::size_t>& type) const

View File

@@ -5,7 +5,7 @@
#include <Nazara/Core/Signal.hpp>
#include <Nazara/Utility/Enums.hpp>
#include <Nazara/Shader/Ast/Nodes.hpp>
#include <Nazara/Shader/Ast/Module.hpp>
#include <nodes/FlowScene>
#include <ShaderNode/Enums.hpp>
#include <ShaderNode/Previews/PreviewModel.hpp>
@@ -67,7 +67,7 @@ class ShaderGraph
void Load(const QJsonObject& data);
QJsonObject Save();
Nz::ShaderAst::StatementPtr ToAst() const;
Nz::ShaderAst::ModulePtr ToModule() const;
Nz::ShaderAst::ExpressionType ToShaderExpressionType(const std::variant<PrimitiveType, std::size_t>& type) const;
void UpdateBuffer(std::size_t bufferIndex, std::string name, BufferType bufferType, std::size_t structIndex, std::size_t setIndex, std::size_t bindingIndex);

View File

@@ -61,18 +61,18 @@ void CodeOutputWidget::Refresh()
for (std::size_t i = 0; i < m_shaderGraph.GetOptionCount(); ++i)
states.optionValues[i] = m_shaderGraph.IsOptionEnabled(i);
Nz::ShaderAst::StatementPtr shaderAst = m_shaderGraph.ToAst();
Nz::ShaderAst::ModulePtr shaderModule = m_shaderGraph.ToModule();
if (m_optimisationCheckbox->isChecked())
{
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOptions;
sanitizeOptions.optionValues = states.optionValues;
shaderAst = Nz::ShaderAst::Sanitize(*shaderAst, sanitizeOptions);
shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOptions);
Nz::ShaderAst::AstConstantPropagationVisitor optimiser;
shaderAst = Nz::ShaderAst::PropagateConstants(*shaderAst);
shaderAst = Nz::ShaderAst::EliminateUnusedPass(*shaderAst);
shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule);
shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule);
}
std::string output;
@@ -89,21 +89,21 @@ void CodeOutputWidget::Refresh()
bindingMapping.emplace(Nz::UInt64(texture.setIndex) << 32 | Nz::UInt64(texture.bindingIndex), bindingMapping.size());
Nz::GlslWriter writer;
output = writer.Generate(ShaderGraph::ToShaderStageType(m_shaderGraph.GetType()), *shaderAst, bindingMapping, states);
output = writer.Generate(ShaderGraph::ToShaderStageType(m_shaderGraph.GetType()), *shaderModule, bindingMapping, states);
break;
}
case OutputLanguage::NZSL:
{
Nz::LangWriter writer;
output = writer.Generate(*shaderAst, states);
output = writer.Generate(*shaderModule, states);
break;
}
case OutputLanguage::SpirV:
{
Nz::SpirvWriter writer;
std::vector<std::uint32_t> spirv = writer.Generate(*shaderAst, states);
std::vector<std::uint32_t> spirv = writer.Generate(*shaderModule, states);
Nz::SpirvPrinter printer;
output = printer.Print(spirv.data(), spirv.size());

View File

@@ -181,7 +181,7 @@ void MainWindow::OnCompile()
{
try
{
auto shader = m_shaderGraph.ToAst();
auto shaderModule = m_shaderGraph.ToModule();
QString fileName = QFileDialog::getSaveFileName(nullptr, tr("Save shader"), QString(), tr("Shader Files (*.shader)"));
if (fileName.isEmpty())
@@ -191,7 +191,7 @@ void MainWindow::OnCompile()
fileName += ".shader";
Nz::File file(fileName.toStdString(), Nz::OpenMode::WriteOnly);
file.Write(Nz::ShaderAst::SerializeShader(shader));
file.Write(Nz::ShaderAst::SerializeShader(*shaderModule));
}
catch (const std::exception& e)
{