Shader/LangWriter: Make LangWriter able to output AST before sanitization as well
This commit is contained in:
@@ -714,10 +714,48 @@ namespace Nz::ShaderAst
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
auto expression = CloneExpression(MandatoryExpr(node.expression));
|
||||
|
||||
return clone;
|
||||
const ExpressionType& exprType = GetExpressionType(*expression);
|
||||
if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType))
|
||||
{
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
if (node.components[i] != 0)
|
||||
throw AstError{ "invalid swizzle" };
|
||||
|
||||
}
|
||||
if (node.componentCount == 1)
|
||||
return expression; //< ignore this swizzle (a.x == a)
|
||||
|
||||
// Use a Cast expression to replace swizzle
|
||||
expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required
|
||||
|
||||
PrimitiveType baseType;
|
||||
if (IsVectorType(exprType))
|
||||
baseType = std::get<VectorType>(exprType).type;
|
||||
else
|
||||
baseType = std::get<PrimitiveType>(exprType);
|
||||
|
||||
auto cast = std::make_unique<CastExpression>();
|
||||
cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } };
|
||||
for (std::size_t j = 0; j < node.componentCount; ++j)
|
||||
cast->expressions[j] = CloneExpression(expression);
|
||||
|
||||
Validate(*cast);
|
||||
|
||||
return cast;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto clone = std::make_unique<SwizzleExpression>();
|
||||
clone->componentCount = node.componentCount;
|
||||
clone->components = node.components;
|
||||
clone->expression = std::move(expression);
|
||||
Validate(*clone);
|
||||
|
||||
return clone;
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
|
||||
@@ -1787,7 +1825,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
|
||||
{
|
||||
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well)
|
||||
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well)
|
||||
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
|
||||
return expression;
|
||||
|
||||
@@ -2940,6 +2978,9 @@ namespace Nz::ShaderAst
|
||||
std::size_t componentCount;
|
||||
if (IsPrimitiveType(exprType))
|
||||
{
|
||||
if (m_context->options.removeScalarSwizzling)
|
||||
throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then
|
||||
|
||||
baseType = std::get<PrimitiveType>(exprType);
|
||||
componentCount = 1;
|
||||
}
|
||||
|
||||
@@ -309,9 +309,10 @@ namespace Nz
|
||||
Append(type.GetResultingValue());
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
{
|
||||
throw std::runtime_error("unexpected function type");
|
||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name;
|
||||
Append(targetName);
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
||||
@@ -903,10 +904,9 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
|
||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name;
|
||||
node.targetFunction->Visit(*this);
|
||||
|
||||
Append(targetName, "(");
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
|
||||
@@ -113,11 +113,12 @@ namespace Nz
|
||||
};
|
||||
|
||||
const States* states = nullptr;
|
||||
ShaderAst::Module* module;
|
||||
const ShaderAst::Module* module;
|
||||
std::size_t currentModuleIndex;
|
||||
std::stringstream stream;
|
||||
std::unordered_map<std::size_t, Identifier> aliases;
|
||||
std::unordered_map<std::size_t, Identifier> constants;
|
||||
std::unordered_map<std::size_t, Identifier> functions;
|
||||
std::unordered_map<std::size_t, Identifier> structs;
|
||||
std::unordered_map<std::size_t, Identifier> variables;
|
||||
std::vector<std::string> moduleNames;
|
||||
@@ -134,14 +135,13 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module);
|
||||
state.module = sanitizedModule.get();
|
||||
state.module = &module;
|
||||
|
||||
AppendHeader();
|
||||
|
||||
// Register imported modules
|
||||
m_currentState->currentModuleIndex = 0;
|
||||
for (const auto& importedModule : sanitizedModule->importedModules)
|
||||
for (const auto& importedModule : module.importedModules)
|
||||
{
|
||||
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
|
||||
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
|
||||
@@ -155,7 +155,7 @@ namespace Nz
|
||||
}
|
||||
|
||||
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
|
||||
sanitizedModule->rootNode->Visit(*this);
|
||||
module.rootNode->Visit(*this);
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
@@ -185,17 +185,22 @@ namespace Nz
|
||||
|
||||
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
|
||||
{
|
||||
Append(type.GetResultingValue());
|
||||
assert(type.HasValue());
|
||||
if (type.IsResultingValue())
|
||||
Append(type.GetResultingValue());
|
||||
else
|
||||
type.GetExpression()->Visit(*this);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
{
|
||||
throw std::runtime_error("unexpected function type");
|
||||
const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name;
|
||||
Append(targetName);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
||||
void LangWriter::Append(const ShaderAst::IdentifierType& identifierType)
|
||||
{
|
||||
throw std::runtime_error("unexpected identifier type");
|
||||
Append(identifierType.name);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
|
||||
@@ -562,6 +567,8 @@ namespace Nz
|
||||
}
|
||||
else
|
||||
attribute.unroll.GetExpression()->Visit(*this);
|
||||
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void LangWriter::AppendAttribute(UuidAttribute attribute)
|
||||
@@ -681,6 +688,16 @@ namespace Nz
|
||||
m_currentState->constants.emplace(constantIndex, std::move(identifier));
|
||||
}
|
||||
|
||||
void LangWriter::RegisterFunction(std::size_t funcIndex, std::string functionName)
|
||||
{
|
||||
State::Identifier identifier;
|
||||
identifier.moduleIndex = m_currentState->currentModuleIndex;
|
||||
identifier.name = std::move(functionName);
|
||||
|
||||
assert(m_currentState->functions.find(funcIndex) == m_currentState->functions.end());
|
||||
m_currentState->functions.emplace(funcIndex, std::move(identifier));
|
||||
}
|
||||
|
||||
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
|
||||
{
|
||||
State::Identifier identifier;
|
||||
@@ -730,9 +747,6 @@ namespace Nz
|
||||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||
assert(IsStructType(exprType));
|
||||
|
||||
for (const std::string& identifier : node.identifiers)
|
||||
Append(".", identifier);
|
||||
}
|
||||
@@ -741,9 +755,6 @@ namespace Nz
|
||||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = ResolveAlias(GetExpressionType(*node.expr));
|
||||
assert(!IsStructType(exprType));
|
||||
|
||||
// Array access
|
||||
Append("[");
|
||||
|
||||
@@ -838,6 +849,21 @@ namespace Nz
|
||||
Visit(node.right, true);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
node.targetFunction->Visit(*this);
|
||||
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
|
||||
node.parameters[i]->Visit(*this);
|
||||
}
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
Append(node.targetType);
|
||||
@@ -880,8 +906,8 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
|
||||
{
|
||||
assert(node.aliasIndex);
|
||||
RegisterAlias(*node.aliasIndex, node.name);
|
||||
if (node.aliasIndex)
|
||||
RegisterAlias(*node.aliasIndex, node.name);
|
||||
|
||||
Append("alias ", node.name, " = ");
|
||||
assert(node.expression);
|
||||
@@ -891,10 +917,13 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
|
||||
{
|
||||
assert(node.constIndex);
|
||||
RegisterConstant(*node.constIndex, node.name);
|
||||
if (node.constIndex)
|
||||
RegisterConstant(*node.constIndex, node.name);
|
||||
|
||||
Append("const ", node.name);
|
||||
if (node.type.HasValue())
|
||||
Append(": ", node.type);
|
||||
|
||||
Append("const ", node.name, ": ", node.type);
|
||||
if (node.expression)
|
||||
{
|
||||
Append(" = ");
|
||||
@@ -940,6 +969,11 @@ namespace Nz
|
||||
AppendIdentifier(m_currentState->constants, node.constantId);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
Append(node.identifier);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
AppendLine("external");
|
||||
@@ -956,8 +990,8 @@ namespace Nz
|
||||
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
|
||||
Append(externalVar.name, ": ", externalVar.type);
|
||||
|
||||
assert(externalVar.varIndex);
|
||||
RegisterVariable(*externalVar.varIndex, externalVar.name);
|
||||
if (externalVar.varIndex)
|
||||
RegisterVariable(*externalVar.varIndex, externalVar.name);
|
||||
}
|
||||
|
||||
LeaveScope();
|
||||
@@ -967,6 +1001,9 @@ namespace Nz
|
||||
{
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
|
||||
if (node.funcIndex)
|
||||
RegisterFunction(*node.funcIndex, node.name);
|
||||
|
||||
AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
|
||||
Append("fn ", node.name, "(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
@@ -980,15 +1017,14 @@ namespace Nz
|
||||
Append(": ");
|
||||
Append(parameter.type);
|
||||
|
||||
assert(parameter.varIndex);
|
||||
RegisterVariable(*parameter.varIndex, parameter.name);
|
||||
if (parameter.varIndex)
|
||||
RegisterVariable(*parameter.varIndex, parameter.name);
|
||||
}
|
||||
Append(")");
|
||||
if (node.returnType.HasValue())
|
||||
{
|
||||
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
|
||||
if (!IsNoType(returnType))
|
||||
Append(" -> ", returnType);
|
||||
if (!node.returnType.IsResultingValue() || !IsNoType(node.returnType.GetResultingValue()))
|
||||
Append(" -> ", node.returnType);
|
||||
}
|
||||
|
||||
AppendLine();
|
||||
@@ -1001,10 +1037,13 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node)
|
||||
{
|
||||
assert(node.optIndex);
|
||||
RegisterConstant(*node.optIndex, node.optName);
|
||||
if (node.optIndex)
|
||||
RegisterConstant(*node.optIndex, node.optName);
|
||||
|
||||
Append("option ", node.optName);
|
||||
if (node.optType.HasValue())
|
||||
Append(": ", node.optType);
|
||||
|
||||
Append("option ", node.optName, ": ", node.optType);
|
||||
if (node.defaultValue)
|
||||
{
|
||||
Append(" = ");
|
||||
@@ -1016,8 +1055,8 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||
{
|
||||
assert(node.structIndex);
|
||||
RegisterStruct(*node.structIndex, node.description.name);
|
||||
if (node.structIndex)
|
||||
RegisterStruct(*node.structIndex, node.description.name);
|
||||
|
||||
AppendAttributes(true, LayoutAttribute{ node.description.layout });
|
||||
Append("struct ");
|
||||
@@ -1041,10 +1080,13 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
if (node.varIndex)
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
Append("let ", node.varName);
|
||||
if (node.varType.HasValue())
|
||||
Append(": ", node.varType);
|
||||
|
||||
Append("let ", node.varName, ": ", node.varType);
|
||||
if (node.initialExpression)
|
||||
{
|
||||
Append(" = ");
|
||||
@@ -1067,8 +1109,8 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::ForStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
if (node.varIndex)
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
||||
Append("for ", node.varName, " in ");
|
||||
@@ -1089,8 +1131,8 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
if (node.varIndex)
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
AppendAttributes(true, UnrollAttribute{ node.unroll });
|
||||
Append("for ", node.varName, " in ");
|
||||
@@ -1102,6 +1144,8 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::ImportStatement& node)
|
||||
{
|
||||
Append("import ");
|
||||
|
||||
bool first = true;
|
||||
for (const std::string& path : node.modulePath)
|
||||
{
|
||||
@@ -1112,6 +1156,8 @@ namespace Nz
|
||||
|
||||
first = false;
|
||||
}
|
||||
|
||||
AppendLine(";");
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
|
||||
Reference in New Issue
Block a user