Shader/LangWriter: Make LangWriter able to output AST before sanitization as well

This commit is contained in:
Jérôme Leclercq 2022-03-12 18:16:30 +01:00
parent 2f26a1d9c7
commit 80f9556f8c
14 changed files with 192 additions and 57 deletions

View File

@ -95,6 +95,7 @@ namespace Nz
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterFunction(std::size_t funcIndex, std::string functionName);
void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName);
@ -107,10 +108,12 @@ namespace Nz
void Visit(ShaderAst::AliasValueExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConditionalExpression& node) override;
void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::StructTypeExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;

View File

@ -714,10 +714,48 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
{
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
Validate(*clone);
auto expression = CloneExpression(MandatoryExpr(node.expression));
return clone;
const ExpressionType& exprType = GetExpressionType(*expression);
if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType))
{
for (std::size_t i = 0; i < node.componentCount; ++i)
{
if (node.components[i] != 0)
throw AstError{ "invalid swizzle" };
}
if (node.componentCount == 1)
return expression; //< ignore this swizzle (a.x == a)
// Use a Cast expression to replace swizzle
expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required
PrimitiveType baseType;
if (IsVectorType(exprType))
baseType = std::get<VectorType>(exprType).type;
else
baseType = std::get<PrimitiveType>(exprType);
auto cast = std::make_unique<CastExpression>();
cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } };
for (std::size_t j = 0; j < node.componentCount; ++j)
cast->expressions[j] = CloneExpression(expression);
Validate(*cast);
return cast;
}
else
{
auto clone = std::make_unique<SwizzleExpression>();
clone->componentCount = node.componentCount;
clone->components = node.components;
clone->expression = std::move(expression);
Validate(*clone);
return clone;
}
}
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
@ -1787,7 +1825,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
{
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well)
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well)
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
return expression;
@ -2940,6 +2978,9 @@ namespace Nz::ShaderAst
std::size_t componentCount;
if (IsPrimitiveType(exprType))
{
if (m_context->options.removeScalarSwizzling)
throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then
baseType = std::get<PrimitiveType>(exprType);
componentCount = 1;
}

View File

@ -309,9 +309,10 @@ namespace Nz
Append(type.GetResultingValue());
}
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
{
throw std::runtime_error("unexpected function type");
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name;
Append(targetName);
}
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
@ -903,10 +904,9 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
{
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name;
node.targetFunction->Visit(*this);
Append(targetName, "(");
Append("(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
{
if (i != 0)

View File

@ -113,11 +113,12 @@ namespace Nz
};
const States* states = nullptr;
ShaderAst::Module* module;
const ShaderAst::Module* module;
std::size_t currentModuleIndex;
std::stringstream stream;
std::unordered_map<std::size_t, Identifier> aliases;
std::unordered_map<std::size_t, Identifier> constants;
std::unordered_map<std::size_t, Identifier> functions;
std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, Identifier> variables;
std::vector<std::string> moduleNames;
@ -134,14 +135,13 @@ namespace Nz
m_currentState = nullptr;
});
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module);
state.module = sanitizedModule.get();
state.module = &module;
AppendHeader();
// Register imported modules
m_currentState->currentModuleIndex = 0;
for (const auto& importedModule : sanitizedModule->importedModules)
for (const auto& importedModule : module.importedModules)
{
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
@ -155,7 +155,7 @@ namespace Nz
}
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
sanitizedModule->rootNode->Visit(*this);
module.rootNode->Visit(*this);
return state.stream.str();
}
@ -185,17 +185,22 @@ namespace Nz
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
{
Append(type.GetResultingValue());
assert(type.HasValue());
if (type.IsResultingValue())
Append(type.GetResultingValue());
else
type.GetExpression()->Visit(*this);
}
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
{
throw std::runtime_error("unexpected function type");
const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name;
Append(targetName);
}
void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
void LangWriter::Append(const ShaderAst::IdentifierType& identifierType)
{
throw std::runtime_error("unexpected identifier type");
Append(identifierType.name);
}
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
@ -562,6 +567,8 @@ namespace Nz
}
else
attribute.unroll.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(UuidAttribute attribute)
@ -681,6 +688,16 @@ namespace Nz
m_currentState->constants.emplace(constantIndex, std::move(identifier));
}
void LangWriter::RegisterFunction(std::size_t funcIndex, std::string functionName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(functionName);
assert(m_currentState->functions.find(funcIndex) == m_currentState->functions.end());
m_currentState->functions.emplace(funcIndex, std::move(identifier));
}
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
{
State::Identifier identifier;
@ -730,9 +747,6 @@ namespace Nz
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers)
Append(".", identifier);
}
@ -741,9 +755,6 @@ namespace Nz
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
assert(!IsStructType(exprType));
// Array access
Append("[");
@ -838,6 +849,21 @@ namespace Nz
Visit(node.right, true);
}
void LangWriter::Visit(ShaderAst::CallFunctionExpression& node)
{
node.targetFunction->Visit(*this);
Append("(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
{
if (i != 0)
Append(", ");
node.parameters[i]->Visit(*this);
}
Append(")");
}
void LangWriter::Visit(ShaderAst::CastExpression& node)
{
Append(node.targetType);
@ -880,8 +906,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{
assert(node.aliasIndex);
RegisterAlias(*node.aliasIndex, node.name);
if (node.aliasIndex)
RegisterAlias(*node.aliasIndex, node.name);
Append("alias ", node.name, " = ");
assert(node.expression);
@ -891,10 +917,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
{
assert(node.constIndex);
RegisterConstant(*node.constIndex, node.name);
if (node.constIndex)
RegisterConstant(*node.constIndex, node.name);
Append("const ", node.name);
if (node.type.HasValue())
Append(": ", node.type);
Append("const ", node.name, ": ", node.type);
if (node.expression)
{
Append(" = ");
@ -940,6 +969,11 @@ namespace Nz
AppendIdentifier(m_currentState->constants, node.constantId);
}
void LangWriter::Visit(ShaderAst::IdentifierExpression& node)
{
Append(node.identifier);
}
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
AppendLine("external");
@ -956,8 +990,8 @@ namespace Nz
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
Append(externalVar.name, ": ", externalVar.type);
assert(externalVar.varIndex);
RegisterVariable(*externalVar.varIndex, externalVar.name);
if (externalVar.varIndex)
RegisterVariable(*externalVar.varIndex, externalVar.name);
}
LeaveScope();
@ -967,6 +1001,9 @@ namespace Nz
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
if (node.funcIndex)
RegisterFunction(*node.funcIndex, node.name);
AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
Append("fn ", node.name, "(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
@ -980,15 +1017,14 @@ namespace Nz
Append(": ");
Append(parameter.type);
assert(parameter.varIndex);
RegisterVariable(*parameter.varIndex, parameter.name);
if (parameter.varIndex)
RegisterVariable(*parameter.varIndex, parameter.name);
}
Append(")");
if (node.returnType.HasValue())
{
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
if (!IsNoType(returnType))
Append(" -> ", returnType);
if (!node.returnType.IsResultingValue() || !IsNoType(node.returnType.GetResultingValue()))
Append(" -> ", node.returnType);
}
AppendLine();
@ -1001,10 +1037,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node)
{
assert(node.optIndex);
RegisterConstant(*node.optIndex, node.optName);
if (node.optIndex)
RegisterConstant(*node.optIndex, node.optName);
Append("option ", node.optName);
if (node.optType.HasValue())
Append(": ", node.optType);
Append("option ", node.optName, ": ", node.optType);
if (node.defaultValue)
{
Append(" = ");
@ -1016,8 +1055,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
assert(node.structIndex);
RegisterStruct(*node.structIndex, node.description.name);
if (node.structIndex)
RegisterStruct(*node.structIndex, node.description.name);
AppendAttributes(true, LayoutAttribute{ node.description.layout });
Append("struct ");
@ -1041,10 +1080,13 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
if (node.varIndex)
RegisterVariable(*node.varIndex, node.varName);
Append("let ", node.varName);
if (node.varType.HasValue())
Append(": ", node.varType);
Append("let ", node.varName, ": ", node.varType);
if (node.initialExpression)
{
Append(" = ");
@ -1067,8 +1109,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ForStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
if (node.varIndex)
RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in ");
@ -1089,8 +1131,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
if (node.varIndex)
RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in ");
@ -1102,6 +1144,8 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ImportStatement& node)
{
Append("import ");
bool first = true;
for (const std::string& path : node.modulePath)
{
@ -1112,6 +1156,8 @@ namespace Nz
first = false;
}
AppendLine(";");
}
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)

View File

@ -31,6 +31,7 @@ external
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
SECTION("Nested AccessMember")
{

View File

@ -51,6 +51,7 @@ fn main(input: In) -> FragOut
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()

View File

@ -36,6 +36,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -115,6 +116,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -189,6 +191,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()

View File

@ -11,7 +11,7 @@
void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult)
{
Nz::ShaderAst::ModulePtr sanitizedShader;
REQUIRE_NOTHROW(sanitizedShader = Nz::ShaderAst::Sanitize(shaderModule, options));
sanitizedShader = SanitizeModule(shaderModule, options);
ExpectNZSL(*sanitizedShader, expectedOptimizedResult);
}

View File

@ -38,6 +38,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -112,7 +113,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -190,7 +191,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -282,7 +283,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()

View File

@ -2,6 +2,7 @@
#include <Nazara/Core/File.hpp>
#include <Nazara/Core/StringExt.hpp>
#include <Nazara/Shader/DirectoryModuleResolver.hpp>
#include <Nazara/Shader/LangWriter.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
@ -74,7 +75,7 @@ fn main(input: InputData) -> OutputData
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
sanitizeOpt.moduleResolver = directoryModuleResolver;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt));
shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
ExpectGLSL(*shaderModule, R"(
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
@ -279,7 +280,7 @@ fn main(input: InputData) -> OutputData
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
sanitizeOpt.moduleResolver = directoryModuleResolver;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt));
shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
ExpectGLSL(*shaderModule, R"(
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4

View File

@ -19,7 +19,7 @@ void PropagateConstantAndExpect(std::string_view sourceCode, std::string_view ex
{
Nz::ShaderAst::ModulePtr shaderModule;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
shaderModule = SanitizeModule(*shaderModule);
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule));
ExpectNZSL(*shaderModule, expectedOptimizedResult);
@ -32,7 +32,7 @@ void EliminateUnusedAndExpect(std::string_view sourceCode, std::string_view expe
Nz::ShaderAst::ModulePtr shaderModule;
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
shaderModule = SanitizeModule(*shaderModule);
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig));
ExpectNZSL(*shaderModule, expectedOptimizedResult);

View File

@ -272,3 +272,31 @@ void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view exp
}
}
}
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module)
{
Nz::ShaderAst::SanitizeVisitor::Options defaultOptions;
return SanitizeModule(module, defaultOptions);
}
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options)
{
Nz::ShaderAst::ModulePtr shaderModule;
WHEN("We sanitize the shader")
{
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
}
WHEN("We output NZSL and try to parse it again")
{
Nz::LangWriter langWriter;
std::string outputCode = langWriter.Generate((shaderModule) ? *shaderModule : module);
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(outputCode), options));
}
// Ensure sanitization
if (!shaderModule)
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
return shaderModule;
}

View File

@ -4,10 +4,14 @@
#define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
#include <Nazara/Shader/Ast/Module.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <string>
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module);
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options);
#endif

View File

@ -25,6 +25,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -72,6 +73,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -122,6 +124,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -168,6 +171,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -221,6 +225,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
@ -272,6 +277,7 @@ fn main()
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()