Shader: Add support for depth_write and early_fragment_tests attributes (+ FragDepth builtin)
This commit is contained in:
@@ -74,6 +74,8 @@ namespace Nz::ShaderAst
|
||||
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->depthWrite = node.depthWrite;
|
||||
clone->earlyFragmentTests = node.earlyFragmentTests;
|
||||
clone->entryStage = node.entryStage;
|
||||
clone->funcIndex = node.funcIndex;
|
||||
clone->name = node.name;
|
||||
|
||||
@@ -228,6 +228,8 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
Value(node.name);
|
||||
Type(node.returnType);
|
||||
OptEnum(node.depthWrite);
|
||||
OptVal(node.earlyFragmentTests);
|
||||
OptEnum(node.entryStage);
|
||||
OptVal(node.funcIndex);
|
||||
Value(node.optionName);
|
||||
|
||||
@@ -30,10 +30,19 @@ namespace Nz::ShaderAst
|
||||
|
||||
struct SanitizeVisitor::Context
|
||||
{
|
||||
struct FunctionData
|
||||
{
|
||||
std::optional<ShaderStageType> stageType;
|
||||
Bitset<> calledFunctions;
|
||||
DeclareFunctionStatement* statement;
|
||||
FunctionFlags flags;
|
||||
};
|
||||
|
||||
Options options;
|
||||
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
|
||||
std::unordered_set<std::string> declaredExternalVar;
|
||||
std::unordered_set<unsigned int> usedBindingIndexes;
|
||||
FunctionData* currentFunction = nullptr;
|
||||
};
|
||||
|
||||
StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& nodePtr, const Options& options, std::string* error)
|
||||
@@ -57,18 +66,15 @@ namespace Nz::ShaderAst
|
||||
// Collect function name and their types
|
||||
if (nodePtr->GetType() == NodeType::MultiStatement)
|
||||
{
|
||||
std::size_t functionIndex = 0;
|
||||
|
||||
const MultiStatement& multiStatement = static_cast<const MultiStatement&>(*nodePtr);
|
||||
for (const auto& statementPtr : multiStatement.statements)
|
||||
for (auto& statementPtr : multiStatement.statements)
|
||||
{
|
||||
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
|
||||
{
|
||||
const DeclareFunctionStatement& funcDeclaration = static_cast<const DeclareFunctionStatement&>(*statementPtr);
|
||||
m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++));
|
||||
}
|
||||
DeclareFunction(static_cast<DeclareFunctionStatement*>(statementPtr.get()));
|
||||
}
|
||||
}
|
||||
else if (nodePtr->GetType() == NodeType::DeclareFunctionStatement)
|
||||
DeclareFunction(static_cast<DeclareFunctionStatement*>(nodePtr.get()));
|
||||
|
||||
try
|
||||
{
|
||||
@@ -81,6 +87,8 @@ namespace Nz::ShaderAst
|
||||
|
||||
*error = err.errMsg;
|
||||
}
|
||||
|
||||
ResolveFunctions();
|
||||
}
|
||||
PopScope();
|
||||
|
||||
@@ -380,7 +388,8 @@ namespace Nz::ShaderAst
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node)
|
||||
{
|
||||
constexpr std::size_t NoFunction = std::numeric_limits<std::size_t>::max();
|
||||
if (!m_context->currentFunction)
|
||||
throw AstError{ "function calls must happen inside a function" };
|
||||
|
||||
auto clone = std::make_unique<CallFunctionExpression>();
|
||||
|
||||
@@ -388,7 +397,7 @@ namespace Nz::ShaderAst
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
clone->parameters.push_back(CloneExpression(node.parameters[i]));
|
||||
|
||||
const DeclareFunctionStatement* referenceFunctionDeclaration;
|
||||
std::size_t targetFuncIndex;
|
||||
if (std::holds_alternative<std::string>(node.targetFunction))
|
||||
{
|
||||
const std::string& functionName = std::get<std::string>(node.targetFunction);
|
||||
@@ -417,28 +426,27 @@ namespace Nz::ShaderAst
|
||||
throw AstError{ "function expected" };
|
||||
|
||||
clone->targetFunction = identifier->index;
|
||||
referenceFunctionDeclaration = m_functions[identifier->index];
|
||||
targetFuncIndex = identifier->index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Identifier not found, maybe the function is declared later
|
||||
auto it = m_functionDeclarations.find(functionName);
|
||||
if (it == m_functionDeclarations.end())
|
||||
auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; });
|
||||
if (it == m_functions.end())
|
||||
throw AstError{ "function " + functionName + " does not exist" };
|
||||
|
||||
clone->targetFunction = it->second.second;
|
||||
targetFuncIndex = std::distance(m_functions.begin(), it);
|
||||
|
||||
referenceFunctionDeclaration = it->second.first;
|
||||
clone->targetFunction = targetFuncIndex;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::size_t funcIndex = std::get<std::size_t>(node.targetFunction);
|
||||
referenceFunctionDeclaration = m_functions[funcIndex];
|
||||
}
|
||||
targetFuncIndex = std::get<std::size_t>(node.targetFunction);
|
||||
|
||||
Validate(*clone, referenceFunctionDeclaration);
|
||||
m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex);
|
||||
|
||||
Validate(*clone, m_functions[targetFuncIndex].node);
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -447,6 +455,19 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
|
||||
|
||||
clone->cachedExpressionType = clone->targetType;
|
||||
clone->targetType = ResolveType(clone->targetType);
|
||||
|
||||
//FIXME: Make proper rules
|
||||
if (IsMatrixType(clone->targetType) && clone->expressions.front())
|
||||
{
|
||||
const ExpressionType& exprType = GetExpressionType(*clone->expressions.front());
|
||||
if (IsMatrixType(exprType) && !clone->expressions[1])
|
||||
{
|
||||
return clone;
|
||||
}
|
||||
}
|
||||
|
||||
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
|
||||
{
|
||||
if (IsVectorType(exprType))
|
||||
@@ -476,9 +497,6 @@ namespace Nz::ShaderAst
|
||||
if (componentCount != requiredComponents)
|
||||
throw AstError{ "component count doesn't match required component count" };
|
||||
|
||||
clone->targetType = ResolveType(clone->targetType);
|
||||
clone->cachedExpressionType = clone->targetType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
@@ -732,9 +750,20 @@ namespace Nz::ShaderAst
|
||||
|
||||
if (node.parameters.size() > 1)
|
||||
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
|
||||
|
||||
if (stageType != ShaderStageType::Fragment)
|
||||
{
|
||||
if (node.depthWrite.has_value())
|
||||
throw AstError{ "only fragment entry-points can have the depth_write attribute" };
|
||||
|
||||
if (node.earlyFragmentTests.has_value())
|
||||
throw AstError{ "only functions with entry(frag) attribute can have the early_fragments_tests attribute" };
|
||||
}
|
||||
}
|
||||
|
||||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->depthWrite = node.depthWrite;
|
||||
clone->earlyFragmentTests = node.earlyFragmentTests;
|
||||
clone->entryStage = node.entryStage;
|
||||
clone->name = node.name;
|
||||
clone->optionName = node.optionName;
|
||||
@@ -743,6 +772,11 @@ namespace Nz::ShaderAst
|
||||
|
||||
SanitizeIdentifier(clone->name);
|
||||
|
||||
Context::FunctionData tempFuncData;
|
||||
tempFuncData.stageType = node.entryStage;
|
||||
|
||||
m_context->currentFunction = &tempFuncData;
|
||||
|
||||
PushScope();
|
||||
{
|
||||
for (auto& parameter : clone->parameters)
|
||||
@@ -761,6 +795,14 @@ namespace Nz::ShaderAst
|
||||
}
|
||||
PopScope();
|
||||
|
||||
m_context->currentFunction = nullptr;
|
||||
|
||||
if (clone->earlyFragmentTests.has_value() && *clone->earlyFragmentTests)
|
||||
{
|
||||
//TODO: warning and disable early fragment tests
|
||||
throw AstError{ "discard is not compatible with early fragment tests" };
|
||||
}
|
||||
|
||||
if (!clone->optionName.empty())
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(node.optionName);
|
||||
@@ -775,7 +817,23 @@ namespace Nz::ShaderAst
|
||||
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
|
||||
}
|
||||
|
||||
clone->funcIndex = RegisterFunction(clone.get());
|
||||
auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node == &node; });
|
||||
assert(it != m_functions.end());
|
||||
assert(!it->defined);
|
||||
|
||||
std::size_t funcIndex = std::distance(m_functions.begin(), it);
|
||||
|
||||
clone->funcIndex = funcIndex;
|
||||
|
||||
auto& funcData = RegisterFunction(funcIndex);
|
||||
funcData.flags = tempFuncData.flags;
|
||||
|
||||
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
|
||||
{
|
||||
assert(i < m_functions.size());
|
||||
auto& targetFunc = m_functions[i];
|
||||
targetFunc.calledByFunctions.UnboundedSet(funcIndex);
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
@@ -840,6 +898,16 @@ namespace Nz::ShaderAst
|
||||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr SanitizeVisitor::Clone(DiscardStatement& node)
|
||||
{
|
||||
if (!m_context->currentFunction)
|
||||
throw AstError{ "discard can only be used inside a function" };
|
||||
|
||||
m_context->currentFunction->flags |= FunctionFlag::DoesDiscard;
|
||||
|
||||
return AstCloner::Clone(node);
|
||||
}
|
||||
|
||||
StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node)
|
||||
{
|
||||
MandatoryExpr(node.expression);
|
||||
@@ -889,34 +957,57 @@ namespace Nz::ShaderAst
|
||||
m_scopeSizes.pop_back();
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterFunction(DeclareFunctionStatement* funcDecl)
|
||||
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement* funcDecl)
|
||||
{
|
||||
if (auto* identifier = FindIdentifier(funcDecl->name))
|
||||
std::size_t functionIndex = m_functions.size();
|
||||
auto& funcData = m_functions.emplace_back();
|
||||
funcData.node = funcDecl;
|
||||
|
||||
return functionIndex;
|
||||
}
|
||||
|
||||
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
||||
{
|
||||
assert(funcIndex < m_functions.size());
|
||||
auto& funcData = m_functions[funcIndex];
|
||||
assert(funcData.defined);
|
||||
|
||||
funcData.flags |= flags;
|
||||
|
||||
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
|
||||
PropagateFunctionFlags(i, funcData.flags, seen);
|
||||
}
|
||||
|
||||
auto SanitizeVisitor::RegisterFunction(std::size_t functionIndex) -> FunctionData&
|
||||
{
|
||||
assert(m_functions.size() >= functionIndex);
|
||||
auto& funcData = m_functions[functionIndex];
|
||||
assert(!funcData.defined);
|
||||
funcData.defined = true;
|
||||
|
||||
if (auto* identifier = FindIdentifier(funcData.node->name))
|
||||
{
|
||||
bool duplicate = true;
|
||||
|
||||
// Functions cannot be declared twice, except for entry ones if their stages are different
|
||||
if (funcDecl->entryStage && identifier->type == Identifier::Type::Function)
|
||||
if (funcData.node->entryStage && identifier->type == Identifier::Type::Function)
|
||||
{
|
||||
auto& otherFunction = m_functions[identifier->index];
|
||||
if (funcDecl->entryStage != otherFunction->entryStage)
|
||||
if (funcData.node->entryStage != otherFunction.node->entryStage)
|
||||
duplicate = false;
|
||||
}
|
||||
|
||||
if (duplicate)
|
||||
throw AstError{ funcDecl->name + " is already used" };
|
||||
throw AstError{ funcData.node->name + " is already used" };
|
||||
}
|
||||
|
||||
std::size_t functionIndex = m_functions.size();
|
||||
m_functions.push_back(funcDecl);
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
funcDecl->name,
|
||||
funcData.node->name,
|
||||
functionIndex,
|
||||
Identifier::Type::Function
|
||||
});
|
||||
});
|
||||
|
||||
return functionIndex;
|
||||
return funcData;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
|
||||
@@ -965,7 +1056,7 @@ namespace Nz::ShaderAst
|
||||
std::move(name),
|
||||
structIndex,
|
||||
Identifier::Type::Struct
|
||||
});
|
||||
});
|
||||
|
||||
return structIndex;
|
||||
}
|
||||
@@ -988,6 +1079,26 @@ namespace Nz::ShaderAst
|
||||
return varIndex;
|
||||
}
|
||||
|
||||
void SanitizeVisitor::ResolveFunctions()
|
||||
{
|
||||
// Once every function is known, we can propagate flags
|
||||
|
||||
Bitset<> seen;
|
||||
for (std::size_t funcIndex = 0; funcIndex < m_functions.size(); ++funcIndex)
|
||||
{
|
||||
auto& funcData = m_functions[funcIndex];
|
||||
|
||||
PropagateFunctionFlags(funcIndex, funcData.flags, seen);
|
||||
seen.Clear();
|
||||
}
|
||||
|
||||
for (const FunctionData& funcData : m_functions)
|
||||
{
|
||||
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage && *funcData.node->entryStage != ShaderStageType::Fragment)
|
||||
throw AstError{ "discard can only be used in the fragment stage" };
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> std::size_t
|
||||
|
||||
Reference in New Issue
Block a user