Shader/LangWriter: Make LangWriter able to output AST before sanitization as well
This commit is contained in:
parent
2f26a1d9c7
commit
80f9556f8c
|
|
@ -95,6 +95,7 @@ namespace Nz
|
|||
|
||||
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
|
||||
void RegisterConstant(std::size_t constantIndex, std::string constantName);
|
||||
void RegisterFunction(std::size_t funcIndex, std::string functionName);
|
||||
void RegisterStruct(std::size_t structIndex, std::string structName);
|
||||
void RegisterVariable(std::size_t varIndex, std::string varName);
|
||||
|
||||
|
|
@ -107,10 +108,12 @@ namespace Nz
|
|||
void Visit(ShaderAst::AliasValueExpression& node) override;
|
||||
void Visit(ShaderAst::AssignExpression& node) override;
|
||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||
void Visit(ShaderAst::CastExpression& node) override;
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override;
|
||||
void Visit(ShaderAst::ConstantValueExpression& node) override;
|
||||
void Visit(ShaderAst::ConstantExpression& node) override;
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override;
|
||||
void Visit(ShaderAst::IntrinsicExpression& node) override;
|
||||
void Visit(ShaderAst::StructTypeExpression& node) override;
|
||||
void Visit(ShaderAst::SwizzleExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -714,10 +714,48 @@ namespace Nz::ShaderAst
|
|||
|
||||
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
auto expression = CloneExpression(MandatoryExpr(node.expression));
|
||||
|
||||
return clone;
|
||||
const ExpressionType& exprType = GetExpressionType(*expression);
|
||||
if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType))
|
||||
{
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
if (node.components[i] != 0)
|
||||
throw AstError{ "invalid swizzle" };
|
||||
|
||||
}
|
||||
if (node.componentCount == 1)
|
||||
return expression; //< ignore this swizzle (a.x == a)
|
||||
|
||||
// Use a Cast expression to replace swizzle
|
||||
expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required
|
||||
|
||||
PrimitiveType baseType;
|
||||
if (IsVectorType(exprType))
|
||||
baseType = std::get<VectorType>(exprType).type;
|
||||
else
|
||||
baseType = std::get<PrimitiveType>(exprType);
|
||||
|
||||
auto cast = std::make_unique<CastExpression>();
|
||||
cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } };
|
||||
for (std::size_t j = 0; j < node.componentCount; ++j)
|
||||
cast->expressions[j] = CloneExpression(expression);
|
||||
|
||||
Validate(*cast);
|
||||
|
||||
return cast;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto clone = std::make_unique<SwizzleExpression>();
|
||||
clone->componentCount = node.componentCount;
|
||||
clone->components = node.components;
|
||||
clone->expression = std::move(expression);
|
||||
Validate(*clone);
|
||||
|
||||
return clone;
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
|
||||
|
|
@ -1787,7 +1825,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
|
||||
{
|
||||
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well)
|
||||
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well)
|
||||
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
|
||||
return expression;
|
||||
|
||||
|
|
@ -2940,6 +2978,9 @@ namespace Nz::ShaderAst
|
|||
std::size_t componentCount;
|
||||
if (IsPrimitiveType(exprType))
|
||||
{
|
||||
if (m_context->options.removeScalarSwizzling)
|
||||
throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then
|
||||
|
||||
baseType = std::get<PrimitiveType>(exprType);
|
||||
componentCount = 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -309,9 +309,10 @@ namespace Nz
|
|||
Append(type.GetResultingValue());
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
{
|
||||
throw std::runtime_error("unexpected function type");
|
||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name;
|
||||
Append(targetName);
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
||||
|
|
@ -903,10 +904,9 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
|
||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name;
|
||||
node.targetFunction->Visit(*this);
|
||||
|
||||
Append(targetName, "(");
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
|
|
|
|||
|
|
@ -113,11 +113,12 @@ namespace Nz
|
|||
};
|
||||
|
||||
const States* states = nullptr;
|
||||
ShaderAst::Module* module;
|
||||
const ShaderAst::Module* module;
|
||||
std::size_t currentModuleIndex;
|
||||
std::stringstream stream;
|
||||
std::unordered_map<std::size_t, Identifier> aliases;
|
||||
std::unordered_map<std::size_t, Identifier> constants;
|
||||
std::unordered_map<std::size_t, Identifier> functions;
|
||||
std::unordered_map<std::size_t, Identifier> structs;
|
||||
std::unordered_map<std::size_t, Identifier> variables;
|
||||
std::vector<std::string> moduleNames;
|
||||
|
|
@ -134,14 +135,13 @@ namespace Nz
|
|||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module);
|
||||
state.module = sanitizedModule.get();
|
||||
state.module = &module;
|
||||
|
||||
AppendHeader();
|
||||
|
||||
// Register imported modules
|
||||
m_currentState->currentModuleIndex = 0;
|
||||
for (const auto& importedModule : sanitizedModule->importedModules)
|
||||
for (const auto& importedModule : module.importedModules)
|
||||
{
|
||||
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
|
||||
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
|
||||
|
|
@ -155,7 +155,7 @@ namespace Nz
|
|||
}
|
||||
|
||||
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
|
||||
sanitizedModule->rootNode->Visit(*this);
|
||||
module.rootNode->Visit(*this);
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
|
|
@ -185,17 +185,22 @@ namespace Nz
|
|||
|
||||
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
|
||||
{
|
||||
Append(type.GetResultingValue());
|
||||
assert(type.HasValue());
|
||||
if (type.IsResultingValue())
|
||||
Append(type.GetResultingValue());
|
||||
else
|
||||
type.GetExpression()->Visit(*this);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
{
|
||||
throw std::runtime_error("unexpected function type");
|
||||
const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name;
|
||||
Append(targetName);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
||||
void LangWriter::Append(const ShaderAst::IdentifierType& identifierType)
|
||||
{
|
||||
throw std::runtime_error("unexpected identifier type");
|
||||
Append(identifierType.name);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
|
||||
|
|
@ -562,6 +567,8 @@ namespace Nz
|
|||
}
|
||||
else
|
||||
attribute.unroll.GetExpression()->Visit(*this);
|
||||
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void LangWriter::AppendAttribute(UuidAttribute attribute)
|
||||
|
|
@ -681,6 +688,16 @@ namespace Nz
|
|||
m_currentState->constants.emplace(constantIndex, std::move(identifier));
|
||||
}
|
||||
|
||||
void LangWriter::RegisterFunction(std::size_t funcIndex, std::string functionName)
|
||||
{
|
||||
State::Identifier identifier;
|
||||
identifier.moduleIndex = m_currentState->currentModuleIndex;
|
||||
identifier.name = std::move(functionName);
|
||||
|
||||
assert(m_currentState->functions.find(funcIndex) == m_currentState->functions.end());
|
||||
m_currentState->functions.emplace(funcIndex, std::move(identifier));
|
||||
}
|
||||
|
||||
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
|
||||
{
|
||||
State::Identifier identifier;
|
||||
|
|
@ -730,9 +747,6 @@ namespace Nz
|
|||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||
assert(IsStructType(exprType));
|
||||
|
||||
for (const std::string& identifier : node.identifiers)
|
||||
Append(".", identifier);
|
||||
}
|
||||
|
|
@ -741,9 +755,6 @@ namespace Nz
|
|||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||
assert(!IsStructType(exprType));
|
||||
|
||||
// Array access
|
||||
Append("[");
|
||||
|
||||
|
|
@ -838,6 +849,21 @@ namespace Nz
|
|||
Visit(node.right, true);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
node.targetFunction->Visit(*this);
|
||||
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
|
||||
node.parameters[i]->Visit(*this);
|
||||
}
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
Append(node.targetType);
|
||||
|
|
@ -880,8 +906,8 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
||||
{
|
||||
assert(node.aliasIndex);
|
||||
RegisterAlias(*node.aliasIndex, node.name);
|
||||
if (node.aliasIndex)
|
||||
RegisterAlias(*node.aliasIndex, node.name);
|
||||
|
||||
Append("alias ", node.name, " = ");
|
||||
assert(node.expression);
|
||||
|
|
@ -891,10 +917,13 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
|
||||
{
|
||||
assert(node.constIndex);
|
||||
RegisterConstant(*node.constIndex, node.name);
|
||||
if (node.constIndex)
|
||||
RegisterConstant(*node.constIndex, node.name);
|
||||
|
||||
Append("const ", node.name);
|
||||
if (node.type.HasValue())
|
||||
Append(": ", node.type);
|
||||
|
||||
Append("const ", node.name, ": ", node.type);
|
||||
if (node.expression)
|
||||
{
|
||||
Append(" = ");
|
||||
|
|
@ -940,6 +969,11 @@ namespace Nz
|
|||
AppendIdentifier(m_currentState->constants, node.constantId);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
Append(node.identifier);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
AppendLine("external");
|
||||
|
|
@ -956,8 +990,8 @@ namespace Nz
|
|||
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
|
||||
Append(externalVar.name, ": ", externalVar.type);
|
||||
|
||||
assert(externalVar.varIndex);
|
||||
RegisterVariable(*externalVar.varIndex, externalVar.name);
|
||||
if (externalVar.varIndex)
|
||||
RegisterVariable(*externalVar.varIndex, externalVar.name);
|
||||
}
|
||||
|
||||
LeaveScope();
|
||||
|
|
@ -967,6 +1001,9 @@ namespace Nz
|
|||
{
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
|
||||
if (node.funcIndex)
|
||||
RegisterFunction(*node.funcIndex, node.name);
|
||||
|
||||
AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
|
||||
Append("fn ", node.name, "(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
|
|
@ -980,15 +1017,14 @@ namespace Nz
|
|||
Append(": ");
|
||||
Append(parameter.type);
|
||||
|
||||
assert(parameter.varIndex);
|
||||
RegisterVariable(*parameter.varIndex, parameter.name);
|
||||
if (parameter.varIndex)
|
||||
RegisterVariable(*parameter.varIndex, parameter.name);
|
||||
}
|
||||
Append(")");
|
||||
if (node.returnType.HasValue())
|
||||
{
|
||||
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
|
||||
if (!IsNoType(returnType))
|
||||
Append(" -> ", returnType);
|
||||
if (!node.returnType.IsResultingValue() || !IsNoType(node.returnType.GetResultingValue()))
|
||||
Append(" -> ", node.returnType);
|
||||
}
|
||||
|
||||
AppendLine();
|
||||
|
|
@ -1001,10 +1037,13 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node)
|
||||
{
|
||||
assert(node.optIndex);
|
||||
RegisterConstant(*node.optIndex, node.optName);
|
||||
if (node.optIndex)
|
||||
RegisterConstant(*node.optIndex, node.optName);
|
||||
|
||||
Append("option ", node.optName);
|
||||
if (node.optType.HasValue())
|
||||
Append(": ", node.optType);
|
||||
|
||||
Append("option ", node.optName, ": ", node.optType);
|
||||
if (node.defaultValue)
|
||||
{
|
||||
Append(" = ");
|
||||
|
|
@ -1016,8 +1055,8 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||
{
|
||||
assert(node.structIndex);
|
||||
RegisterStruct(*node.structIndex, node.description.name);
|
||||
if (node.structIndex)
|
||||
RegisterStruct(*node.structIndex, node.description.name);
|
||||
|
||||
AppendAttributes(true, LayoutAttribute{ node.description.layout });
|
||||
Append("struct ");
|
||||
|
|
@ -1041,10 +1080,13 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
if (node.varIndex)
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
Append("let ", node.varName);
|
||||
if (node.varType.HasValue())
|
||||
Append(": ", node.varType);
|
||||
|
||||
Append("let ", node.varName, ": ", node.varType);
|
||||
if (node.initialExpression)
|
||||
{
|
||||
Append(" = ");
|
||||
|
|
@ -1067,8 +1109,8 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::ForStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
if (node.varIndex)
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
||||
Append("for ", node.varName, " in ");
|
||||
|
|
@ -1089,8 +1131,8 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
if (node.varIndex)
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
||||
Append("for ", node.varName, " in ");
|
||||
|
|
@ -1102,6 +1144,8 @@ namespace Nz
|
|||
|
||||
void LangWriter::Visit(ShaderAst::ImportStatement& node)
|
||||
{
|
||||
Append("import ");
|
||||
|
||||
bool first = true;
|
||||
for (const std::string& path : node.modulePath)
|
||||
{
|
||||
|
|
@ -1112,6 +1156,8 @@ namespace Nz
|
|||
|
||||
first = false;
|
||||
}
|
||||
|
||||
AppendLine(";");
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ external
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
SECTION("Nested AccessMember")
|
||||
{
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ fn main(input: In) -> FragOut
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -115,6 +116,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -189,6 +191,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult)
|
||||
{
|
||||
Nz::ShaderAst::ModulePtr sanitizedShader;
|
||||
REQUIRE_NOTHROW(sanitizedShader = Nz::ShaderAst::Sanitize(shaderModule, options));
|
||||
sanitizedShader = SanitizeModule(shaderModule, options);
|
||||
|
||||
ExpectNZSL(*sanitizedShader, expectedOptimizedResult);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -112,7 +113,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -190,7 +191,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -282,7 +283,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include <Nazara/Core/File.hpp>
|
||||
#include <Nazara/Core/StringExt.hpp>
|
||||
#include <Nazara/Shader/DirectoryModuleResolver.hpp>
|
||||
#include <Nazara/Shader/LangWriter.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderLangParser.hpp>
|
||||
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||
|
|
@ -74,7 +75,7 @@ fn main(input: InputData) -> OutputData
|
|||
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
|
||||
sanitizeOpt.moduleResolver = directoryModuleResolver;
|
||||
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt));
|
||||
shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
|
||||
|
|
@ -279,7 +280,7 @@ fn main(input: InputData) -> OutputData
|
|||
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
|
||||
sanitizeOpt.moduleResolver = directoryModuleResolver;
|
||||
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt));
|
||||
shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ void PropagateConstantAndExpect(std::string_view sourceCode, std::string_view ex
|
|||
{
|
||||
Nz::ShaderAst::ModulePtr shaderModule;
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule));
|
||||
|
||||
ExpectNZSL(*shaderModule, expectedOptimizedResult);
|
||||
|
|
@ -32,7 +32,7 @@ void EliminateUnusedAndExpect(std::string_view sourceCode, std::string_view expe
|
|||
|
||||
Nz::ShaderAst::ModulePtr shaderModule;
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig));
|
||||
|
||||
ExpectNZSL(*shaderModule, expectedOptimizedResult);
|
||||
|
|
|
|||
|
|
@ -272,3 +272,31 @@ void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view exp
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module)
|
||||
{
|
||||
Nz::ShaderAst::SanitizeVisitor::Options defaultOptions;
|
||||
return SanitizeModule(module, defaultOptions);
|
||||
}
|
||||
|
||||
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options)
|
||||
{
|
||||
Nz::ShaderAst::ModulePtr shaderModule;
|
||||
WHEN("We sanitize the shader")
|
||||
{
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
|
||||
}
|
||||
|
||||
WHEN("We output NZSL and try to parse it again")
|
||||
{
|
||||
Nz::LangWriter langWriter;
|
||||
std::string outputCode = langWriter.Generate((shaderModule) ? *shaderModule : module);
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(outputCode), options));
|
||||
}
|
||||
|
||||
// Ensure sanitization
|
||||
if (!shaderModule)
|
||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
|
||||
|
||||
return shaderModule;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,14 @@
|
|||
#define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
|
||||
|
||||
#include <Nazara/Shader/Ast/Module.hpp>
|
||||
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||
#include <string>
|
||||
|
||||
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
||||
|
||||
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module);
|
||||
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options);
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -72,6 +73,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -122,6 +124,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -168,6 +171,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -221,6 +225,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
@ -272,6 +277,7 @@ fn main()
|
|||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue