Shader/LangWriter: Make LangWriter able to output AST before sanitization as well
This commit is contained in:
@@ -95,6 +95,7 @@ namespace Nz
|
|||||||
|
|
||||||
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
|
void RegisterAlias(std::size_t aliasIndex, std::string aliasName);
|
||||||
void RegisterConstant(std::size_t constantIndex, std::string constantName);
|
void RegisterConstant(std::size_t constantIndex, std::string constantName);
|
||||||
|
void RegisterFunction(std::size_t funcIndex, std::string functionName);
|
||||||
void RegisterStruct(std::size_t structIndex, std::string structName);
|
void RegisterStruct(std::size_t structIndex, std::string structName);
|
||||||
void RegisterVariable(std::size_t varIndex, std::string varName);
|
void RegisterVariable(std::size_t varIndex, std::string varName);
|
||||||
|
|
||||||
@@ -107,10 +108,12 @@ namespace Nz
|
|||||||
void Visit(ShaderAst::AliasValueExpression& node) override;
|
void Visit(ShaderAst::AliasValueExpression& node) override;
|
||||||
void Visit(ShaderAst::AssignExpression& node) override;
|
void Visit(ShaderAst::AssignExpression& node) override;
|
||||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||||
|
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||||
void Visit(ShaderAst::CastExpression& node) override;
|
void Visit(ShaderAst::CastExpression& node) override;
|
||||||
void Visit(ShaderAst::ConditionalExpression& node) override;
|
void Visit(ShaderAst::ConditionalExpression& node) override;
|
||||||
void Visit(ShaderAst::ConstantValueExpression& node) override;
|
void Visit(ShaderAst::ConstantValueExpression& node) override;
|
||||||
void Visit(ShaderAst::ConstantExpression& node) override;
|
void Visit(ShaderAst::ConstantExpression& node) override;
|
||||||
|
void Visit(ShaderAst::IdentifierExpression& node) override;
|
||||||
void Visit(ShaderAst::IntrinsicExpression& node) override;
|
void Visit(ShaderAst::IntrinsicExpression& node) override;
|
||||||
void Visit(ShaderAst::StructTypeExpression& node) override;
|
void Visit(ShaderAst::StructTypeExpression& node) override;
|
||||||
void Visit(ShaderAst::SwizzleExpression& node) override;
|
void Visit(ShaderAst::SwizzleExpression& node) override;
|
||||||
|
|||||||
@@ -714,10 +714,48 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
|
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
|
||||||
{
|
{
|
||||||
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
|
auto expression = CloneExpression(MandatoryExpr(node.expression));
|
||||||
Validate(*clone);
|
|
||||||
|
|
||||||
return clone;
|
const ExpressionType& exprType = GetExpressionType(*expression);
|
||||||
|
if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType))
|
||||||
|
{
|
||||||
|
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||||
|
{
|
||||||
|
if (node.components[i] != 0)
|
||||||
|
throw AstError{ "invalid swizzle" };
|
||||||
|
|
||||||
|
}
|
||||||
|
if (node.componentCount == 1)
|
||||||
|
return expression; //< ignore this swizzle (a.x == a)
|
||||||
|
|
||||||
|
// Use a Cast expression to replace swizzle
|
||||||
|
expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required
|
||||||
|
|
||||||
|
PrimitiveType baseType;
|
||||||
|
if (IsVectorType(exprType))
|
||||||
|
baseType = std::get<VectorType>(exprType).type;
|
||||||
|
else
|
||||||
|
baseType = std::get<PrimitiveType>(exprType);
|
||||||
|
|
||||||
|
auto cast = std::make_unique<CastExpression>();
|
||||||
|
cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } };
|
||||||
|
for (std::size_t j = 0; j < node.componentCount; ++j)
|
||||||
|
cast->expressions[j] = CloneExpression(expression);
|
||||||
|
|
||||||
|
Validate(*cast);
|
||||||
|
|
||||||
|
return cast;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
auto clone = std::make_unique<SwizzleExpression>();
|
||||||
|
clone->componentCount = node.componentCount;
|
||||||
|
clone->components = node.components;
|
||||||
|
clone->expression = std::move(expression);
|
||||||
|
Validate(*clone);
|
||||||
|
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
|
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
|
||||||
@@ -1787,7 +1825,7 @@ namespace Nz::ShaderAst
|
|||||||
|
|
||||||
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
|
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
|
||||||
{
|
{
|
||||||
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well)
|
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well)
|
||||||
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
|
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
|
||||||
return expression;
|
return expression;
|
||||||
|
|
||||||
@@ -2940,6 +2978,9 @@ namespace Nz::ShaderAst
|
|||||||
std::size_t componentCount;
|
std::size_t componentCount;
|
||||||
if (IsPrimitiveType(exprType))
|
if (IsPrimitiveType(exprType))
|
||||||
{
|
{
|
||||||
|
if (m_context->options.removeScalarSwizzling)
|
||||||
|
throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then
|
||||||
|
|
||||||
baseType = std::get<PrimitiveType>(exprType);
|
baseType = std::get<PrimitiveType>(exprType);
|
||||||
componentCount = 1;
|
componentCount = 1;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -309,9 +309,10 @@ namespace Nz
|
|||||||
Append(type.GetResultingValue());
|
Append(type.GetResultingValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected function type");
|
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name;
|
||||||
|
Append(targetName);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
||||||
@@ -903,10 +904,9 @@ namespace Nz
|
|||||||
|
|
||||||
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||||
{
|
{
|
||||||
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
|
node.targetFunction->Visit(*this);
|
||||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name;
|
|
||||||
|
|
||||||
Append(targetName, "(");
|
Append("(");
|
||||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||||
{
|
{
|
||||||
if (i != 0)
|
if (i != 0)
|
||||||
|
|||||||
@@ -113,11 +113,12 @@ namespace Nz
|
|||||||
};
|
};
|
||||||
|
|
||||||
const States* states = nullptr;
|
const States* states = nullptr;
|
||||||
ShaderAst::Module* module;
|
const ShaderAst::Module* module;
|
||||||
std::size_t currentModuleIndex;
|
std::size_t currentModuleIndex;
|
||||||
std::stringstream stream;
|
std::stringstream stream;
|
||||||
std::unordered_map<std::size_t, Identifier> aliases;
|
std::unordered_map<std::size_t, Identifier> aliases;
|
||||||
std::unordered_map<std::size_t, Identifier> constants;
|
std::unordered_map<std::size_t, Identifier> constants;
|
||||||
|
std::unordered_map<std::size_t, Identifier> functions;
|
||||||
std::unordered_map<std::size_t, Identifier> structs;
|
std::unordered_map<std::size_t, Identifier> structs;
|
||||||
std::unordered_map<std::size_t, Identifier> variables;
|
std::unordered_map<std::size_t, Identifier> variables;
|
||||||
std::vector<std::string> moduleNames;
|
std::vector<std::string> moduleNames;
|
||||||
@@ -134,14 +135,13 @@ namespace Nz
|
|||||||
m_currentState = nullptr;
|
m_currentState = nullptr;
|
||||||
});
|
});
|
||||||
|
|
||||||
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module);
|
state.module = &module;
|
||||||
state.module = sanitizedModule.get();
|
|
||||||
|
|
||||||
AppendHeader();
|
AppendHeader();
|
||||||
|
|
||||||
// Register imported modules
|
// Register imported modules
|
||||||
m_currentState->currentModuleIndex = 0;
|
m_currentState->currentModuleIndex = 0;
|
||||||
for (const auto& importedModule : sanitizedModule->importedModules)
|
for (const auto& importedModule : module.importedModules)
|
||||||
{
|
{
|
||||||
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
|
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
|
||||||
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
|
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
|
||||||
@@ -155,7 +155,7 @@ namespace Nz
|
|||||||
}
|
}
|
||||||
|
|
||||||
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
|
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
|
||||||
sanitizedModule->rootNode->Visit(*this);
|
module.rootNode->Visit(*this);
|
||||||
|
|
||||||
return state.stream.str();
|
return state.stream.str();
|
||||||
}
|
}
|
||||||
@@ -185,17 +185,22 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
|
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
|
||||||
{
|
{
|
||||||
Append(type.GetResultingValue());
|
assert(type.HasValue());
|
||||||
|
if (type.IsResultingValue())
|
||||||
|
Append(type.GetResultingValue());
|
||||||
|
else
|
||||||
|
type.GetExpression()->Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected function type");
|
const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name;
|
||||||
|
Append(targetName);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
void LangWriter::Append(const ShaderAst::IdentifierType& identifierType)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected identifier type");
|
Append(identifierType.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
|
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
|
||||||
@@ -562,6 +567,8 @@ namespace Nz
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
attribute.unroll.GetExpression()->Visit(*this);
|
attribute.unroll.GetExpression()->Visit(*this);
|
||||||
|
|
||||||
|
Append(")");
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::AppendAttribute(UuidAttribute attribute)
|
void LangWriter::AppendAttribute(UuidAttribute attribute)
|
||||||
@@ -681,6 +688,16 @@ namespace Nz
|
|||||||
m_currentState->constants.emplace(constantIndex, std::move(identifier));
|
m_currentState->constants.emplace(constantIndex, std::move(identifier));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LangWriter::RegisterFunction(std::size_t funcIndex, std::string functionName)
|
||||||
|
{
|
||||||
|
State::Identifier identifier;
|
||||||
|
identifier.moduleIndex = m_currentState->currentModuleIndex;
|
||||||
|
identifier.name = std::move(functionName);
|
||||||
|
|
||||||
|
assert(m_currentState->functions.find(funcIndex) == m_currentState->functions.end());
|
||||||
|
m_currentState->functions.emplace(funcIndex, std::move(identifier));
|
||||||
|
}
|
||||||
|
|
||||||
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
|
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
|
||||||
{
|
{
|
||||||
State::Identifier identifier;
|
State::Identifier identifier;
|
||||||
@@ -730,9 +747,6 @@ namespace Nz
|
|||||||
{
|
{
|
||||||
Visit(node.expr, true);
|
Visit(node.expr, true);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
|
||||||
assert(IsStructType(exprType));
|
|
||||||
|
|
||||||
for (const std::string& identifier : node.identifiers)
|
for (const std::string& identifier : node.identifiers)
|
||||||
Append(".", identifier);
|
Append(".", identifier);
|
||||||
}
|
}
|
||||||
@@ -741,9 +755,6 @@ namespace Nz
|
|||||||
{
|
{
|
||||||
Visit(node.expr, true);
|
Visit(node.expr, true);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
|
||||||
assert(!IsStructType(exprType));
|
|
||||||
|
|
||||||
// Array access
|
// Array access
|
||||||
Append("[");
|
Append("[");
|
||||||
|
|
||||||
@@ -838,6 +849,21 @@ namespace Nz
|
|||||||
Visit(node.right, true);
|
Visit(node.right, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LangWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||||
|
{
|
||||||
|
node.targetFunction->Visit(*this);
|
||||||
|
|
||||||
|
Append("(");
|
||||||
|
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||||
|
{
|
||||||
|
if (i != 0)
|
||||||
|
Append(", ");
|
||||||
|
|
||||||
|
node.parameters[i]->Visit(*this);
|
||||||
|
}
|
||||||
|
Append(")");
|
||||||
|
}
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::CastExpression& node)
|
void LangWriter::Visit(ShaderAst::CastExpression& node)
|
||||||
{
|
{
|
||||||
Append(node.targetType);
|
Append(node.targetType);
|
||||||
@@ -880,8 +906,8 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.aliasIndex);
|
if (node.aliasIndex)
|
||||||
RegisterAlias(*node.aliasIndex, node.name);
|
RegisterAlias(*node.aliasIndex, node.name);
|
||||||
|
|
||||||
Append("alias ", node.name, " = ");
|
Append("alias ", node.name, " = ");
|
||||||
assert(node.expression);
|
assert(node.expression);
|
||||||
@@ -891,10 +917,13 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.constIndex);
|
if (node.constIndex)
|
||||||
RegisterConstant(*node.constIndex, node.name);
|
RegisterConstant(*node.constIndex, node.name);
|
||||||
|
|
||||||
|
Append("const ", node.name);
|
||||||
|
if (node.type.HasValue())
|
||||||
|
Append(": ", node.type);
|
||||||
|
|
||||||
Append("const ", node.name, ": ", node.type);
|
|
||||||
if (node.expression)
|
if (node.expression)
|
||||||
{
|
{
|
||||||
Append(" = ");
|
Append(" = ");
|
||||||
@@ -940,6 +969,11 @@ namespace Nz
|
|||||||
AppendIdentifier(m_currentState->constants, node.constantId);
|
AppendIdentifier(m_currentState->constants, node.constantId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LangWriter::Visit(ShaderAst::IdentifierExpression& node)
|
||||||
|
{
|
||||||
|
Append(node.identifier);
|
||||||
|
}
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||||
{
|
{
|
||||||
AppendLine("external");
|
AppendLine("external");
|
||||||
@@ -956,8 +990,8 @@ namespace Nz
|
|||||||
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
|
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
|
||||||
Append(externalVar.name, ": ", externalVar.type);
|
Append(externalVar.name, ": ", externalVar.type);
|
||||||
|
|
||||||
assert(externalVar.varIndex);
|
if (externalVar.varIndex)
|
||||||
RegisterVariable(*externalVar.varIndex, externalVar.name);
|
RegisterVariable(*externalVar.varIndex, externalVar.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
LeaveScope();
|
LeaveScope();
|
||||||
@@ -967,6 +1001,9 @@ namespace Nz
|
|||||||
{
|
{
|
||||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||||
|
|
||||||
|
if (node.funcIndex)
|
||||||
|
RegisterFunction(*node.funcIndex, node.name);
|
||||||
|
|
||||||
AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
|
AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
|
||||||
Append("fn ", node.name, "(");
|
Append("fn ", node.name, "(");
|
||||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||||
@@ -980,15 +1017,14 @@ namespace Nz
|
|||||||
Append(": ");
|
Append(": ");
|
||||||
Append(parameter.type);
|
Append(parameter.type);
|
||||||
|
|
||||||
assert(parameter.varIndex);
|
if (parameter.varIndex)
|
||||||
RegisterVariable(*parameter.varIndex, parameter.name);
|
RegisterVariable(*parameter.varIndex, parameter.name);
|
||||||
}
|
}
|
||||||
Append(")");
|
Append(")");
|
||||||
if (node.returnType.HasValue())
|
if (node.returnType.HasValue())
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
|
if (!node.returnType.IsResultingValue() || !IsNoType(node.returnType.GetResultingValue()))
|
||||||
if (!IsNoType(returnType))
|
Append(" -> ", node.returnType);
|
||||||
Append(" -> ", returnType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AppendLine();
|
AppendLine();
|
||||||
@@ -1001,10 +1037,13 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.optIndex);
|
if (node.optIndex)
|
||||||
RegisterConstant(*node.optIndex, node.optName);
|
RegisterConstant(*node.optIndex, node.optName);
|
||||||
|
|
||||||
|
Append("option ", node.optName);
|
||||||
|
if (node.optType.HasValue())
|
||||||
|
Append(": ", node.optType);
|
||||||
|
|
||||||
Append("option ", node.optName, ": ", node.optType);
|
|
||||||
if (node.defaultValue)
|
if (node.defaultValue)
|
||||||
{
|
{
|
||||||
Append(" = ");
|
Append(" = ");
|
||||||
@@ -1016,8 +1055,8 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.structIndex);
|
if (node.structIndex)
|
||||||
RegisterStruct(*node.structIndex, node.description.name);
|
RegisterStruct(*node.structIndex, node.description.name);
|
||||||
|
|
||||||
AppendAttributes(true, LayoutAttribute{ node.description.layout });
|
AppendAttributes(true, LayoutAttribute{ node.description.layout });
|
||||||
Append("struct ");
|
Append("struct ");
|
||||||
@@ -1041,10 +1080,13 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.varIndex);
|
if (node.varIndex)
|
||||||
RegisterVariable(*node.varIndex, node.varName);
|
RegisterVariable(*node.varIndex, node.varName);
|
||||||
|
|
||||||
|
Append("let ", node.varName);
|
||||||
|
if (node.varType.HasValue())
|
||||||
|
Append(": ", node.varType);
|
||||||
|
|
||||||
Append("let ", node.varName, ": ", node.varType);
|
|
||||||
if (node.initialExpression)
|
if (node.initialExpression)
|
||||||
{
|
{
|
||||||
Append(" = ");
|
Append(" = ");
|
||||||
@@ -1067,8 +1109,8 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::ForStatement& node)
|
void LangWriter::Visit(ShaderAst::ForStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.varIndex);
|
if (node.varIndex)
|
||||||
RegisterVariable(*node.varIndex, node.varName);
|
RegisterVariable(*node.varIndex, node.varName);
|
||||||
|
|
||||||
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
||||||
Append("for ", node.varName, " in ");
|
Append("for ", node.varName, " in ");
|
||||||
@@ -1089,8 +1131,8 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
|
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
|
||||||
{
|
{
|
||||||
assert(node.varIndex);
|
if (node.varIndex)
|
||||||
RegisterVariable(*node.varIndex, node.varName);
|
RegisterVariable(*node.varIndex, node.varName);
|
||||||
|
|
||||||
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
||||||
Append("for ", node.varName, " in ");
|
Append("for ", node.varName, " in ");
|
||||||
@@ -1102,6 +1144,8 @@ namespace Nz
|
|||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::ImportStatement& node)
|
void LangWriter::Visit(ShaderAst::ImportStatement& node)
|
||||||
{
|
{
|
||||||
|
Append("import ");
|
||||||
|
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (const std::string& path : node.modulePath)
|
for (const std::string& path : node.modulePath)
|
||||||
{
|
{
|
||||||
@@ -1112,6 +1156,8 @@ namespace Nz
|
|||||||
|
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AppendLine(";");
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)
|
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ external
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
SECTION("Nested AccessMember")
|
SECTION("Nested AccessMember")
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ fn main(input: In) -> FragOut
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -115,6 +116,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -189,6 +191,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult)
|
void ExpectOutput(Nz::ShaderAst::Module& shaderModule, const Nz::ShaderAst::SanitizeVisitor::Options& options, std::string_view expectedOptimizedResult)
|
||||||
{
|
{
|
||||||
Nz::ShaderAst::ModulePtr sanitizedShader;
|
Nz::ShaderAst::ModulePtr sanitizedShader;
|
||||||
REQUIRE_NOTHROW(sanitizedShader = Nz::ShaderAst::Sanitize(shaderModule, options));
|
sanitizedShader = SanitizeModule(shaderModule, options);
|
||||||
|
|
||||||
ExpectNZSL(*sanitizedShader, expectedOptimizedResult);
|
ExpectNZSL(*sanitizedShader, expectedOptimizedResult);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -112,7 +113,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -190,7 +191,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -282,7 +283,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#include <Nazara/Core/File.hpp>
|
#include <Nazara/Core/File.hpp>
|
||||||
#include <Nazara/Core/StringExt.hpp>
|
#include <Nazara/Core/StringExt.hpp>
|
||||||
#include <Nazara/Shader/DirectoryModuleResolver.hpp>
|
#include <Nazara/Shader/DirectoryModuleResolver.hpp>
|
||||||
|
#include <Nazara/Shader/LangWriter.hpp>
|
||||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||||
#include <Nazara/Shader/ShaderLangParser.hpp>
|
#include <Nazara/Shader/ShaderLangParser.hpp>
|
||||||
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||||
@@ -74,7 +75,7 @@ fn main(input: InputData) -> OutputData
|
|||||||
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
|
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
|
||||||
sanitizeOpt.moduleResolver = directoryModuleResolver;
|
sanitizeOpt.moduleResolver = directoryModuleResolver;
|
||||||
|
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt));
|
shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
|
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
|
||||||
@@ -279,7 +280,7 @@ fn main(input: InputData) -> OutputData
|
|||||||
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
|
Nz::ShaderAst::SanitizeVisitor::Options sanitizeOpt;
|
||||||
sanitizeOpt.moduleResolver = directoryModuleResolver;
|
sanitizeOpt.moduleResolver = directoryModuleResolver;
|
||||||
|
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt));
|
shaderModule = SanitizeModule(*shaderModule, sanitizeOpt);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
|
// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ void PropagateConstantAndExpect(std::string_view sourceCode, std::string_view ex
|
|||||||
{
|
{
|
||||||
Nz::ShaderAst::ModulePtr shaderModule;
|
Nz::ShaderAst::ModulePtr shaderModule;
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule));
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::PropagateConstants(*shaderModule));
|
||||||
|
|
||||||
ExpectNZSL(*shaderModule, expectedOptimizedResult);
|
ExpectNZSL(*shaderModule, expectedOptimizedResult);
|
||||||
@@ -32,7 +32,7 @@ void EliminateUnusedAndExpect(std::string_view sourceCode, std::string_view expe
|
|||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule;
|
Nz::ShaderAst::ModulePtr shaderModule;
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderLang::Parse(sourceCode));
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*shaderModule));
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig));
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::EliminateUnusedPass(*shaderModule, depConfig));
|
||||||
|
|
||||||
ExpectNZSL(*shaderModule, expectedOptimizedResult);
|
ExpectNZSL(*shaderModule, expectedOptimizedResult);
|
||||||
|
|||||||
@@ -272,3 +272,31 @@ void ExpectSPIRV(const Nz::ShaderAst::Module& shaderModule, std::string_view exp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module)
|
||||||
|
{
|
||||||
|
Nz::ShaderAst::SanitizeVisitor::Options defaultOptions;
|
||||||
|
return SanitizeModule(module, defaultOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options)
|
||||||
|
{
|
||||||
|
Nz::ShaderAst::ModulePtr shaderModule;
|
||||||
|
WHEN("We sanitize the shader")
|
||||||
|
{
|
||||||
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
|
||||||
|
}
|
||||||
|
|
||||||
|
WHEN("We output NZSL and try to parse it again")
|
||||||
|
{
|
||||||
|
Nz::LangWriter langWriter;
|
||||||
|
std::string outputCode = langWriter.Generate((shaderModule) ? *shaderModule : module);
|
||||||
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(outputCode), options));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure sanitization
|
||||||
|
if (!shaderModule)
|
||||||
|
REQUIRE_NOTHROW(shaderModule = Nz::ShaderAst::Sanitize(module, options));
|
||||||
|
|
||||||
|
return shaderModule;
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,10 +4,14 @@
|
|||||||
#define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
|
#define NAZARA_UNITTESTS_SHADER_SHADERUTILS_HPP
|
||||||
|
|
||||||
#include <Nazara/Shader/Ast/Module.hpp>
|
#include <Nazara/Shader/Ast/Module.hpp>
|
||||||
|
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
void ExpectGLSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||||
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
void ExpectNZSL(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput);
|
||||||
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
void ExpectSPIRV(const Nz::ShaderAst::Module& shader, std::string_view expectedOutput, bool outputParameter = false);
|
||||||
|
|
||||||
|
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module);
|
||||||
|
Nz::ShaderAst::ModulePtr SanitizeModule(const Nz::ShaderAst::Module& module, const Nz::ShaderAst::SanitizeVisitor::Options& options);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -72,6 +73,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -122,6 +124,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -168,6 +171,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -221,6 +225,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
@@ -272,6 +277,7 @@ fn main()
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||||
|
shaderModule = SanitizeModule(*shaderModule);
|
||||||
|
|
||||||
ExpectGLSL(*shaderModule, R"(
|
ExpectGLSL(*shaderModule, R"(
|
||||||
void main()
|
void main()
|
||||||
|
|||||||
Reference in New Issue
Block a user