Shader/LangWriter: Make LangWriter able to output AST before sanitization as well

This commit is contained in:
Jérôme Leclercq
2022-03-12 18:16:30 +01:00
parent 2f26a1d9c7
commit 80f9556f8c
14 changed files with 192 additions and 57 deletions

View File

@@ -95,6 +95,7 @@ namespace Nz
void RegisterAlias(std::size_t aliasIndex, std::string aliasName); void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
void RegisterConstant(std::size_t constantIndex, std::string constantName); void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterFunction(std::size_t funcIndex, std::string functionName);
void RegisterStruct(std::size_t structIndex, std::string structName); void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName); void RegisterVariable(std::size_t varIndex, std::string varName);
@@ -107,10 +108,12 @@ namespace Nz
void Visit(ShaderAst::AliasValueExpression& node) override; void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConditionalExpression& node) override;
void Visit(ShaderAst::ConstantValueExpression& node) override; void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::StructTypeExpression& node) override; void Visit(ShaderAst::StructTypeExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override;

View File

@@ -714,10 +714,48 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node) ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
{ {
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node)); auto expression = CloneExpression(MandatoryExpr(node.expression));
Validate(*clone);
return clone; const ExpressionType& exprType = GetExpressionType(*expression);
if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType))
{
for (std::size_t i = 0; i < node.componentCount; ++i)
{
if (node.components[i] != 0)
throw AstError{ "invalid swizzle" };
}
if (node.componentCount == 1)
return expression; //< ignore this swizzle (a.x == a)
// Use a Cast expression to replace swizzle
expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required
PrimitiveType baseType;
if (IsVectorType(exprType))
baseType = std::get<VectorType>(exprType).type;
else
baseType = std::get<PrimitiveType>(exprType);
auto cast = std::make_unique<CastExpression>();
cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } };
for (std::size_t j = 0; j < node.componentCount; ++j)
cast->expressions[j] = CloneExpression(expression);
Validate(*cast);
return cast;
}
else
{
auto clone = std::make_unique<SwizzleExpression>();
clone->componentCount = node.componentCount;
clone->components = node.components;
clone->expression = std::move(expression);
Validate(*clone);
return clone;
}
} }
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node) ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
@@ -1787,7 +1825,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression) ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
{ {
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well) // No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well)
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue) if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
return expression; return expression;
@@ -2940,6 +2978,9 @@ namespace Nz::ShaderAst
std::size_t componentCount; std::size_t componentCount;
if (IsPrimitiveType(exprType)) if (IsPrimitiveType(exprType))
{ {
if (m_context->options.removeScalarSwizzling)
throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then
baseType = std::get<PrimitiveType>(exprType); baseType = std::get<PrimitiveType>(exprType);
componentCount = 1; componentCount = 1;
} }

View File

@@ -309,9 +309,10 @@ namespace Nz
Append(type.GetResultingValue()); Append(type.GetResultingValue());
} }
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/) void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
{ {
throw std::runtime_error("unexpected function type"); const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name;
Append(targetName);
} }
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/) void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
@@ -903,10 +904,9 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node) void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
{ {
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex; node.targetFunction->Visit(*this);
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name;
Append(targetName, "("); Append("(");
for (std::size_t i = 0; i < node.parameters.size(); ++i) for (std::size_t i = 0; i < node.parameters.size(); ++i)
{ {
if (i != 0) if (i != 0)

View File

@@ -113,11 +113,12 @@ namespace Nz
}; };
const States* states = nullptr; const States* states = nullptr;
ShaderAst::Module* module; const ShaderAst::Module* module;
std::size_t currentModuleIndex; std::size_t currentModuleIndex;
std::stringstream stream; std::stringstream stream;
std::unordered_map<std::size_t, Identifier> aliases; std::unordered_map<std::size_t, Identifier> aliases;
std::unordered_map<std::size_t, Identifier> constants; std::unordered_map<std::size_t, Identifier> constants;
std::unordered_map<std::size_t, Identifier> functions;
std::unordered_map<std::size_t, Identifier> structs; std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, Identifier> variables; std::unordered_map<std::size_t, Identifier> variables;
std::vector<std::string> moduleNames; std::vector<std::string> moduleNames;
@@ -134,14 +135,13 @@ namespace Nz
m_currentState = nullptr; m_currentState = nullptr;
}); });
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module); state.module = &module;
state.module = sanitizedModule.get();
AppendHeader(); AppendHeader();
// Register imported modules // Register imported modules
m_currentState->currentModuleIndex = 0; m_currentState->currentModuleIndex = 0;
for (const auto& importedModule : sanitizedModule->importedModules) for (const auto& importedModule : module.importedModules)
{ {
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion }); AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId }); AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
@@ -155,7 +155,7 @@ namespace Nz
} }
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max(); m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
sanitizedModule->rootNode->Visit(*this); module.rootNode->Visit(*this);
return state.stream.str(); return state.stream.str();
} }
@@ -185,17 +185,22 @@ namespace Nz
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type) void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
{ {
Append(type.GetResultingValue()); assert(type.HasValue());
if (type.IsResultingValue())
Append(type.GetResultingValue());
else
type.GetExpression()->Visit(*this);
} }
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/) void LangWriter::Append(const ShaderAst::FunctionType& functionType)
{ {
throw std::runtime_error("unexpected function type"); const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name;
Append(targetName);
} }
void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/) void LangWriter::Append(const ShaderAst::IdentifierType& identifierType)
{ {
throw std::runtime_error("unexpected identifier type"); Append(identifierType.name);
} }
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/) void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
@@ -562,6 +567,8 @@ namespace Nz
} }
else else
attribute.unroll.GetExpression()->Visit(*this); attribute.unroll.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(UuidAttribute attribute) void LangWriter::AppendAttribute(UuidAttribute attribute)
@@ -681,6 +688,16 @@ namespace Nz
m_currentState->constants.emplace(constantIndex, std::move(identifier)); m_currentState->constants.emplace(constantIndex, std::move(identifier));
} }
void LangWriter::RegisterFunction(std::size_t funcIndex, std::string functionName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(functionName);
assert(m_currentState->functions.find(funcIndex) == m_currentState->functions.end());
m_currentState->functions.emplace(funcIndex, std::move(identifier));
}
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName) void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
{ {
State::Identifier identifier; State::Identifier identifier;
@@ -730,9 +747,6 @@ namespace Nz
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers) for (const std::string& identifier : node.identifiers)
Append(".", identifier); Append(".", identifier);
} }
@@ -741,9 +755,6 @@ namespace Nz
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(!IsStructType(exprType));
// Array access // Array access
Append("["); Append("[");
@@ -838,6 +849,21 @@ namespace Nz
Visit(node.right, true); Visit(node.right, true);
} }
void LangWriter::Visit(ShaderAst::CallFunctionExpression& node)
{
node.targetFunction->Visit(*this);
Append("(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
{
if (i != 0)
Append(", ");
node.parameters[i]->Visit(*this);
}
Append(")");
}
void LangWriter::Visit(ShaderAst::CastExpression& node) void LangWriter::Visit(ShaderAst::CastExpression& node)
{ {
Append(node.targetType); Append(node.targetType);
@@ -880,8 +906,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node) void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{ {
assert(node.aliasIndex); if (node.aliasIndex)
RegisterAlias(*node.aliasIndex, node.name); RegisterAlias(*node.aliasIndex, node.name);
Append("alias ", node.name, " = "); Append("alias ", node.name, " = ");
assert(node.expression); assert(node.expression);
@@ -891,10 +917,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node) void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
{ {
assert(node.constIndex); if (node.constIndex)
RegisterConstant(*node.constIndex, node.name); RegisterConstant(*node.constIndex, node.name);
Append("const ", node.name);
if (node.type.HasValue())
Append(": ", node.type);
Append("const ", node.name, ": ", node.type);
if (node.expression) if (node.expression)
{ {
Append(" = "); Append(" = ");
@@ -940,6 +969,11 @@ namespace Nz
AppendIdentifier(m_currentState->constants, node.constantId); AppendIdentifier(m_currentState->constants, node.constantId);
} }
void LangWriter::Visit(ShaderAst::IdentifierExpression& node)
{
Append(node.identifier);
}
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node) void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{ {
AppendLine("external"); AppendLine("external");
@@ -956,8 +990,8 @@ namespace Nz
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex }); AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
Append(externalVar.name, ": ", externalVar.type); Append(externalVar.name, ": ", externalVar.type);
assert(externalVar.varIndex); if (externalVar.varIndex)
RegisterVariable(*externalVar.varIndex, externalVar.name); RegisterVariable(*externalVar.varIndex, externalVar.name);
} }
LeaveScope(); LeaveScope();
@@ -967,6 +1001,9 @@ namespace Nz
{ {
NazaraAssert(m_currentState, "This function should only be called while processing an AST"); NazaraAssert(m_currentState, "This function should only be called while processing an AST");
if (node.funcIndex)
RegisterFunction(*node.funcIndex, node.name);
AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite }); AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
Append("fn ", node.name, "("); Append("fn ", node.name, "(");
for (std::size_t i = 0; i < node.parameters.size(); ++i) for (std::size_t i = 0; i < node.parameters.size(); ++i)
@@ -980,15 +1017,14 @@ namespace Nz
Append(": "); Append(": ");
Append(parameter.type); Append(parameter.type);
assert(parameter.varIndex); if (parameter.varIndex)
RegisterVariable(*parameter.varIndex, parameter.name); RegisterVariable(*parameter.varIndex, parameter.name);
} }
Append(")"); Append(")");
if (node.returnType.HasValue()) if (node.returnType.HasValue())
{ {
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue(); if (!node.returnType.IsResultingValue() || !IsNoType(node.returnType.GetResultingValue()))
if (!IsNoType(returnType)) Append(" -> ", node.returnType);
Append(" -> ", returnType);
} }
AppendLine(); AppendLine();
@@ -1001,10 +1037,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node) void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node)
{ {
assert(node.optIndex); if (node.optIndex)
RegisterConstant(*node.optIndex, node.optName); RegisterConstant(*node.optIndex, node.optName);
Append("option ", node.optName);
if (node.optType.HasValue())
Append(": ", node.optType);
Append("option ", node.optName, ": ", node.optType);
if (node.defaultValue) if (node.defaultValue)
{ {
Append(" = "); Append(" = ");
@@ -1016,8 +1055,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node) void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
{ {
assert(node.structIndex); if (node.structIndex)
RegisterStruct(*node.structIndex, node.description.name); RegisterStruct(*node.structIndex, node.description.name);
AppendAttributes(true, LayoutAttribute{ node.description.layout }); AppendAttributes(true, LayoutAttribute{ node.description.layout });
Append("struct "); Append("struct ");
@@ -1041,10 +1080,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node) void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{ {
assert(node.varIndex); if (node.varIndex)
RegisterVariable(*node.varIndex, node.varName); RegisterVariable(*node.varIndex, node.varName);
Append("let ", node.varName);
if (node.varType.HasValue())
Append(": ", node.varType);
Append("let ", node.varName, ": ", node.varType);
if (node.initialExpression) if (node.initialExpression)
{ {
Append(" = "); Append(" = ");
@@ -1067,8 +1109,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ForStatement& node) void LangWriter::Visit(ShaderAst::ForStatement& node)
{ {
assert(node.varIndex); if (node.varIndex)
RegisterVariable(*node.varIndex, node.varName); RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll }); AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in "); Append("for ", node.varName, " in ");
@@ -1089,8 +1131,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ForEachStatement& node) void LangWriter::Visit(ShaderAst::ForEachStatement& node)
{ {
assert(node.varIndex); if (node.varIndex)
RegisterVariable(*node.varIndex, node.varName); RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll }); AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in "); Append("for ", node.varName, " in ");
@@ -1102,6 +1144,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ImportStatement& node) void LangWriter::Visit(ShaderAst::ImportStatement& node)
{ {
Append("import ");
bool first = true; bool first = true;
for (const std::string& path : node.modulePath) for (const std::string& path : node.modulePath)
{ {
@@ -1112,6 +1156,8 @@ namespace Nz
first = false; first = false;
} }
AppendLine(";");
} }
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node) void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)

View File

@@ -31,6 +31,7 @@ external
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
SECTION("Nested AccessMember") SECTION("Nested AccessMember")
{ {

View File

@@ -51,6 +51,7 @@ fn main(input: In) -> FragOut
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()

View File

@@ -36,6 +36,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -115,6 +116,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -189,6 +191,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()

View File

@@ -11,7 +11,7 @@
void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult) void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult)
{ {
Nz::ShaderAst::ModulePtr sanitizedShader; Nz::ShaderAst::ModulePtr sanitizedShader;
REQUIRE_NOTHROW(sanitizedShader = Nz::ShaderAst::Sanitize(shaderModule, options)); sanitizedShader = SanitizeModule(shaderModule, options);
ExpectNZSL(*sanitizedShader, expectedOptimizedResult); ExpectNZSL(*sanitizedShader, expectedOptimizedResult);
} }

View File

@@ -38,6 +38,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -112,7 +113,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -190,7 +191,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -282,7 +283,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()

View File

@@ -2,6 +2,7 @@
#include <Nazara/Core/File.hpp> #include <Nazara/Core/File.hpp>
#include <Nazara/Core/StringExt.hpp> #include <Nazara/Core/StringExt.hpp>
#include <Nazara/Shader/DirectoryModuleResolver.hpp> #include <Nazara/Shader/DirectoryModuleResolver.hpp>
#include <Nazara/Shader/LangWriter.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp> #include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp> #include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp> #include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
@@ -74,7 +75,7 @@ fn main(input: InputData) -> OutputData
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt; Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
sanitizeOpt.moduleResolver = directoryModuleResolver; sanitizeOpt.moduleResolver = directoryModuleResolver;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt)); shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4 // Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
@@ -279,7 +280,7 @@ fn main(input: InputData) -> OutputData
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt; Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
sanitizeOpt.moduleResolver = directoryModuleResolver; sanitizeOpt.moduleResolver = directoryModuleResolver;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt)); shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4 // Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4

View File

@@ -19,7 +19,7 @@ void PropagateConstantAndExpect(std::string_view sourceCode, std::string_view ex
{ {
Nz::ShaderAst::ModulePtr shaderModule; Nz::ShaderAst::ModulePtr shaderModule;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule)); shaderModule = SanitizeModule(*shaderModule);
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule)); REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule));
ExpectNZSL(*shaderModule, expectedOptimizedResult); ExpectNZSL(*shaderModule, expectedOptimizedResult);
@@ -32,7 +32,7 @@ void EliminateUnusedAndExpect(std::string_view sourceCode, std::string_view expe
Nz::ShaderAst::ModulePtr shaderModule; Nz::ShaderAst::ModulePtr shaderModule;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode)); REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule)); shaderModule = SanitizeModule(*shaderModule);
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig)); REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig));
ExpectNZSL(*shaderModule, expectedOptimizedResult); ExpectNZSL(*shaderModule, expectedOptimizedResult);

View File

@@ -272,3 +272,31 @@ void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view exp
} }
} }
} }
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module)
{
Nz::ShaderAst::SanitizeVisitor::Options defaultOptions;
return SanitizeModule(module, defaultOptions);
}
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options)
{
Nz::ShaderAst::ModulePtr shaderModule;
WHEN("We sanitize the shader")
{
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
}
WHEN("We output NZSL and try to parse it again")
{
Nz::LangWriter langWriter;
std::string outputCode = langWriter.Generate((shaderModule) ? *shaderModule : module);
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(outputCode), options));
}
// Ensure sanitization
if (!shaderModule)
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
return shaderModule;
}

View File

@@ -4,10 +4,14 @@
#define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP #define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
#include <Nazara/Shader/Ast/Module.hpp> #include <Nazara/Shader/Ast/Module.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <string> #include <string>
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput); void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput); void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false); void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module);
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options);
#endif #endif

View File

@@ -25,6 +25,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -72,6 +73,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -122,6 +124,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -168,6 +171,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -221,6 +225,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()
@@ -272,6 +277,7 @@ fn main()
)"; )";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource); Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"( ExpectGLSL(*shaderModule, R"(
void main() void main()