Shader: Add support for depth_write and early_fragment_tests attributes (+ FragDepth builtin)

This commit is contained in:
Jérôme Leclercq 2021-06-01 12:32:24 +02:00
parent 465837ff12
commit 16e2f5f819
12 changed files with 456 additions and 136 deletions

View File

@ -8,97 +8,131 @@
#define NAZARA_SHADER_AST_ENUMS_HPP #define NAZARA_SHADER_AST_ENUMS_HPP
#include <Nazara/Prerequisites.hpp> #include <Nazara/Prerequisites.hpp>
#include <Nazara/Core/Flags.hpp>
namespace Nz::ShaderAst namespace Nz
{ {
enum class AssignType namespace ShaderAst
{ {
Simple //< = enum class AssignType
{
Simple //< =
};
enum class AttributeType
{
Binding, //< Binding (external var only) - has argument index
Builtin, //< Builtin (struct member only) - has argument type
DepthWrite, //< Depth write mode (function only) - has argument type
EarlyFragmentTests, //< Entry point (function only) - has argument on/off
Entry, //< Entry point (function only) - has argument type
Layout, //< Struct layout (struct only) - has argument style
Location, //< Location (struct member only) - has argument index
Option, //< Conditional compilation option - has argument expr
};
enum class BinaryType
{
Add, //< +
Subtract, //< -
Multiply, //< *
Divide, //< /
CompEq, //< ==
CompGe, //< >=
CompGt, //< >
CompLe, //< <=
CompLt, //< <
CompNe //< <=
};
enum class BuiltinEntry
{
FragCoord = 1, // gl_FragCoord
FragDepth = 2, // gl_FragDepth
VertexPosition = 0, // gl_Position
};
enum class DepthWriteMode
{
Greater,
Less,
Replace,
Unchanged,
};
enum class ExpressionCategory
{
LValue,
RValue
};
enum class FunctionFlag
{
DoesDiscard,
DoesWriteFragDepth,
Max = DoesWriteFragDepth
};
}
template<>
struct EnumAsFlags<ShaderAst::FunctionFlag>
{
static constexpr ShaderAst::FunctionFlag max = ShaderAst::FunctionFlag::Max;
}; };
enum class AttributeType namespace ShaderAst
{ {
Binding, //< Binding (external var only) - has argument index using FunctionFlags = Flags<FunctionFlag>;
Builtin, //< Builtin (struct member only) - has argument type
Entry, //< Entry point (function only) - has argument type
Layout, //< Struct layout (struct only) - has argument style
Location, //< Location (struct member only) - has argument index
Option, //< Conditional compilation option - has argument expr
};
enum class BinaryType enum class IntrinsicType
{ {
Add, //< + CrossProduct = 0,
Subtract, //< - DotProduct = 1,
Multiply, //< * Length = 3,
Divide, //< / Max = 4,
Min = 5,
SampleTexture = 2,
};
CompEq, //< == enum class MemoryLayout
CompGe, //< >= {
CompGt, //< > Std140
CompLe, //< <= };
CompLt, //< <
CompNe //< <=
};
enum class BuiltinEntry enum class NodeType
{ {
FragCoord = 1, // gl_FragCoord None = -1,
VertexPosition = 0, // gl_Position
};
enum class ExpressionCategory
{
LValue,
RValue
};
enum class IntrinsicType
{
CrossProduct = 0,
DotProduct = 1,
Length = 3,
Max = 4,
Min = 5,
SampleTexture = 2,
};
enum class MemoryLayout
{
Std140
};
enum class NodeType
{
None = -1,
#define NAZARA_SHADERAST_NODE(Node) Node, #define NAZARA_SHADERAST_NODE(Node) Node,
#define NAZARA_SHADERAST_STATEMENT_LAST(Node) Node, Max = Node #define NAZARA_SHADERAST_STATEMENT_LAST(Node) Node, Max = Node
#include <Nazara/Shader/Ast/AstNodeList.hpp> #include <Nazara/Shader/Ast/AstNodeList.hpp>
}; };
enum class PrimitiveType enum class PrimitiveType
{ {
Boolean, //< bool Boolean, //< bool
Float32, //< f32 Float32, //< f32
Int32, //< i32 Int32, //< i32
UInt32, //< ui32 UInt32, //< ui32
}; };
enum class SwizzleComponent enum class SwizzleComponent
{ {
First, First,
Second, Second,
Third, Third,
Fourth Fourth
}; };
enum class UnaryType enum class UnaryType
{ {
LogicalNot, //< !v LogicalNot, //< !v
Minus, //< -v Minus, //< -v
Plus, //< +v Plus, //< +v
}; };
}
} }
#endif // NAZARA_SHADER_ENUMS_HPP #endif // NAZARA_SHADER_ENUMS_HPP

View File

@ -272,6 +272,8 @@ namespace Nz::ShaderAst
ExpressionType type; ExpressionType type;
}; };
std::optional<DepthWriteMode> depthWrite;
std::optional<bool> earlyFragmentTests;
std::optional<ShaderStageType> entryStage; std::optional<ShaderStageType> entryStage;
std::optional<std::size_t> funcIndex; std::optional<std::size_t> funcIndex;
std::optional<std::size_t> varIndex; std::optional<std::size_t> varIndex;

View File

@ -8,6 +8,7 @@
#define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP #define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP
#include <Nazara/Prerequisites.hpp> #include <Nazara/Prerequisites.hpp>
#include <Nazara/Core/Bitset.hpp>
#include <Nazara/Shader/Config.hpp> #include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/Ast/AstCloner.hpp> #include <Nazara/Shader/Ast/AstCloner.hpp>
#include <unordered_map> #include <unordered_map>
@ -39,6 +40,7 @@ namespace Nz::ShaderAst
}; };
private: private:
struct FunctionData;
struct Identifier; struct Identifier;
const ExpressionType& CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices); const ExpressionType& CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices);
@ -65,6 +67,7 @@ namespace Nz::ShaderAst
StatementPtr Clone(DeclareOptionStatement& node) override; StatementPtr Clone(DeclareOptionStatement& node) override;
StatementPtr Clone(DeclareStructStatement& node) override; StatementPtr Clone(DeclareStructStatement& node) override;
StatementPtr Clone(DeclareVariableStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override;
StatementPtr Clone(DiscardStatement& node) override;
StatementPtr Clone(ExpressionStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override;
StatementPtr Clone(MultiStatement& node) override; StatementPtr Clone(MultiStatement& node) override;
@ -78,12 +81,18 @@ namespace Nz::ShaderAst
void PushScope(); void PushScope();
void PopScope(); void PopScope();
std::size_t RegisterFunction(DeclareFunctionStatement* funcDecl); std::size_t DeclareFunction(DeclareFunctionStatement* funcDecl);
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
FunctionData& RegisterFunction(std::size_t functionIndex);
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterOption(std::string name, ExpressionType type); std::size_t RegisterOption(std::string name, ExpressionType type);
std::size_t RegisterStruct(std::string name, StructDescription description); std::size_t RegisterStruct(std::string name, StructDescription description);
std::size_t RegisterVariable(std::string name, ExpressionType type); std::size_t RegisterVariable(std::string name, ExpressionType type);
void ResolveFunctions();
std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const StructType& structType);
@ -95,6 +104,14 @@ namespace Nz::ShaderAst
void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration); void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration);
void Validate(IntrinsicExpression& node); void Validate(IntrinsicExpression& node);
struct FunctionData
{
Bitset<> calledByFunctions;
DeclareFunctionStatement* node;
FunctionFlags flags;
bool defined = false;
};
struct Identifier struct Identifier
{ {
enum class Type enum class Type
@ -112,9 +129,8 @@ namespace Nz::ShaderAst
Type type; Type type;
}; };
std::unordered_map<std::string /*functionName*/, std::pair<const DeclareFunctionStatement*, std::size_t>> m_functionDeclarations;
std::vector<Identifier> m_identifiersInScope; std::vector<Identifier> m_identifiersInScope;
std::vector<DeclareFunctionStatement*> m_functions; std::vector<FunctionData> m_functions;
std::vector<IntrinsicType> m_intrinsics; std::vector<IntrinsicType> m_intrinsics;
std::vector<ExpressionType> m_options; std::vector<ExpressionType> m_options;
std::vector<StructDescription> m_structs; std::vector<StructDescription> m_structs;

View File

@ -39,6 +39,8 @@ namespace Nz
private: private:
struct BindingAttribute; struct BindingAttribute;
struct BuiltinAttribute; struct BuiltinAttribute;
struct DepthWriteAttribute;
struct EarlyFragmentTestsAttribute;
struct EntryAttribute; struct EntryAttribute;
struct LayoutAttribute; struct LayoutAttribute;
struct LocationAttribute; struct LocationAttribute;
@ -55,8 +57,10 @@ namespace Nz
template<typename T> void Append(const T& param); template<typename T> void Append(const T& param);
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params); template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
template<typename... Args> void AppendAttributes(bool appendLine, Args&&... params); template<typename... Args> void AppendAttributes(bool appendLine, Args&&... params);
void AppendAttribute(BindingAttribute builtin); void AppendAttribute(BindingAttribute binding);
void AppendAttribute(BuiltinAttribute builtin); void AppendAttribute(BuiltinAttribute builtin);
void AppendAttribute(DepthWriteAttribute depthWrite);
void AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests);
void AppendAttribute(EntryAttribute entry); void AppendAttribute(EntryAttribute entry);
void AppendAttribute(LayoutAttribute layout); void AppendAttribute(LayoutAttribute layout);
void AppendAttribute(LocationAttribute location); void AppendAttribute(LocationAttribute location);

View File

@ -95,6 +95,7 @@ namespace Nz
std::optional<UInt32> outputStructTypeId; std::optional<UInt32> outputStructTypeId;
std::vector<Input> inputs; std::vector<Input> inputs;
std::vector<Output> outputs; std::vector<Output> outputs;
std::vector<SpirvExecutionMode> executionModes;
}; };
struct FuncData struct FuncData

View File

@ -74,6 +74,8 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
{ {
auto clone = std::make_unique<DeclareFunctionStatement>(); auto clone = std::make_unique<DeclareFunctionStatement>();
clone->depthWrite = node.depthWrite;
clone->earlyFragmentTests = node.earlyFragmentTests;
clone->entryStage = node.entryStage; clone->entryStage = node.entryStage;
clone->funcIndex = node.funcIndex; clone->funcIndex = node.funcIndex;
clone->name = node.name; clone->name = node.name;

View File

@ -228,6 +228,8 @@ namespace Nz::ShaderAst
{ {
Value(node.name); Value(node.name);
Type(node.returnType); Type(node.returnType);
OptEnum(node.depthWrite);
OptVal(node.earlyFragmentTests);
OptEnum(node.entryStage); OptEnum(node.entryStage);
OptVal(node.funcIndex); OptVal(node.funcIndex);
Value(node.optionName); Value(node.optionName);

View File

@ -30,10 +30,19 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Context struct SanitizeVisitor::Context
{ {
struct FunctionData
{
std::optional<ShaderStageType> stageType;
Bitset<> calledFunctions;
DeclareFunctionStatement* statement;
FunctionFlags flags;
};
Options options; Options options;
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {}; std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar; std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<unsigned int> usedBindingIndexes; std::unordered_set<unsigned int> usedBindingIndexes;
FunctionData* currentFunction = nullptr;
}; };
StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& nodePtr, const Options& options, std::string* error) StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& nodePtr, const Options& options, std::string* error)
@ -57,18 +66,15 @@ namespace Nz::ShaderAst
// Collect function name and their types // Collect function name and their types
if (nodePtr->GetType() == NodeType::MultiStatement) if (nodePtr->GetType() == NodeType::MultiStatement)
{ {
std::size_t functionIndex = 0;
const MultiStatement& multiStatement = static_cast<const MultiStatement&>(*nodePtr); const MultiStatement& multiStatement = static_cast<const MultiStatement&>(*nodePtr);
for (const auto& statementPtr : multiStatement.statements) for (auto& statementPtr : multiStatement.statements)
{ {
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement) if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
{ DeclareFunction(static_cast<DeclareFunctionStatement*>(statementPtr.get()));
const DeclareFunctionStatement& funcDeclaration = static_cast<const DeclareFunctionStatement&>(*statementPtr);
m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++));
}
} }
} }
else if (nodePtr->GetType() == NodeType::DeclareFunctionStatement)
DeclareFunction(static_cast<DeclareFunctionStatement*>(nodePtr.get()));
try try
{ {
@ -81,6 +87,8 @@ namespace Nz::ShaderAst
*error = err.errMsg; *error = err.errMsg;
} }
ResolveFunctions();
} }
PopScope(); PopScope();
@ -380,7 +388,8 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node)
{ {
constexpr std::size_t NoFunction = std::numeric_limits<std::size_t>::max(); if (!m_context->currentFunction)
throw AstError{ "function calls must happen inside a function" };
auto clone = std::make_unique<CallFunctionExpression>(); auto clone = std::make_unique<CallFunctionExpression>();
@ -388,7 +397,7 @@ namespace Nz::ShaderAst
for (std::size_t i = 0; i < node.parameters.size(); ++i) for (std::size_t i = 0; i < node.parameters.size(); ++i)
clone->parameters.push_back(CloneExpression(node.parameters[i])); clone->parameters.push_back(CloneExpression(node.parameters[i]));
const DeclareFunctionStatement* referenceFunctionDeclaration; std::size_t targetFuncIndex;
if (std::holds_alternative<std::string>(node.targetFunction)) if (std::holds_alternative<std::string>(node.targetFunction))
{ {
const std::string& functionName = std::get<std::string>(node.targetFunction); const std::string& functionName = std::get<std::string>(node.targetFunction);
@ -417,28 +426,27 @@ namespace Nz::ShaderAst
throw AstError{ "function expected" }; throw AstError{ "function expected" };
clone->targetFunction = identifier->index; clone->targetFunction = identifier->index;
referenceFunctionDeclaration = m_functions[identifier->index]; targetFuncIndex = identifier->index;
} }
} }
else else
{ {
// Identifier not found, maybe the function is declared later // Identifier not found, maybe the function is declared later
auto it = m_functionDeclarations.find(functionName); auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; });
if (it == m_functionDeclarations.end()) if (it == m_functions.end())
throw AstError{ "function " + functionName + " does not exist" }; throw AstError{ "function " + functionName + " does not exist" };
clone->targetFunction = it->second.second; targetFuncIndex = std::distance(m_functions.begin(), it);
referenceFunctionDeclaration = it->second.first; clone->targetFunction = targetFuncIndex;
} }
} }
else else
{ targetFuncIndex = std::get<std::size_t>(node.targetFunction);
std::size_t funcIndex = std::get<std::size_t>(node.targetFunction);
referenceFunctionDeclaration = m_functions[funcIndex];
}
Validate(*clone, referenceFunctionDeclaration); m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex);
Validate(*clone, m_functions[targetFuncIndex].node);
return clone; return clone;
} }
@ -447,6 +455,19 @@ namespace Nz::ShaderAst
{ {
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = clone->targetType;
clone->targetType = ResolveType(clone->targetType);
//FIXME: Make proper rules
if (IsMatrixType(clone->targetType) && clone->expressions.front())
{
const ExpressionType& exprType = GetExpressionType(*clone->expressions.front());
if (IsMatrixType(exprType) && !clone->expressions[1])
{
return clone;
}
}
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
{ {
if (IsVectorType(exprType)) if (IsVectorType(exprType))
@ -476,9 +497,6 @@ namespace Nz::ShaderAst
if (componentCount != requiredComponents) if (componentCount != requiredComponents)
throw AstError{ "component count doesn't match required component count" }; throw AstError{ "component count doesn't match required component count" };
clone->targetType = ResolveType(clone->targetType);
clone->cachedExpressionType = clone->targetType;
return clone; return clone;
} }
@ -732,9 +750,20 @@ namespace Nz::ShaderAst
if (node.parameters.size() > 1) if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" }; throw AstError{ "entry functions can either take one struct parameter or no parameter" };
if (stageType != ShaderStageType::Fragment)
{
if (node.depthWrite.has_value())
throw AstError{ "only fragment entry-points can have the depth_write attribute" };
if (node.earlyFragmentTests.has_value())
throw AstError{ "only functions with entry(frag) attribute can have the early_fragments_tests attribute" };
}
} }
auto clone = std::make_unique<DeclareFunctionStatement>(); auto clone = std::make_unique<DeclareFunctionStatement>();
clone->depthWrite = node.depthWrite;
clone->earlyFragmentTests = node.earlyFragmentTests;
clone->entryStage = node.entryStage; clone->entryStage = node.entryStage;
clone->name = node.name; clone->name = node.name;
clone->optionName = node.optionName; clone->optionName = node.optionName;
@ -743,6 +772,11 @@ namespace Nz::ShaderAst
SanitizeIdentifier(clone->name); SanitizeIdentifier(clone->name);
Context::FunctionData tempFuncData;
tempFuncData.stageType = node.entryStage;
m_context->currentFunction = &tempFuncData;
PushScope(); PushScope();
{ {
for (auto& parameter : clone->parameters) for (auto& parameter : clone->parameters)
@ -761,6 +795,14 @@ namespace Nz::ShaderAst
} }
PopScope(); PopScope();
m_context->currentFunction = nullptr;
if (clone->earlyFragmentTests.has_value() && *clone->earlyFragmentTests)
{
//TODO: warning and disable early fragment tests
throw AstError{ "discard is not compatible with early fragment tests" };
}
if (!clone->optionName.empty()) if (!clone->optionName.empty())
{ {
const Identifier* identifier = FindIdentifier(node.optionName); const Identifier* identifier = FindIdentifier(node.optionName);
@ -775,7 +817,23 @@ namespace Nz::ShaderAst
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone)); return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
} }
clone->funcIndex = RegisterFunction(clone.get()); auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node == &node; });
assert(it != m_functions.end());
assert(!it->defined);
std::size_t funcIndex = std::distance(m_functions.begin(), it);
clone->funcIndex = funcIndex;
auto& funcData = RegisterFunction(funcIndex);
funcData.flags = tempFuncData.flags;
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
{
assert(i < m_functions.size());
auto& targetFunc = m_functions[i];
targetFunc.calledByFunctions.UnboundedSet(funcIndex);
}
return clone; return clone;
} }
@ -840,6 +898,16 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
StatementPtr SanitizeVisitor::Clone(DiscardStatement& node)
{
if (!m_context->currentFunction)
throw AstError{ "discard can only be used inside a function" };
m_context->currentFunction->flags |= FunctionFlag::DoesDiscard;
return AstCloner::Clone(node);
}
StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node) StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node)
{ {
MandatoryExpr(node.expression); MandatoryExpr(node.expression);
@ -889,34 +957,57 @@ namespace Nz::ShaderAst
m_scopeSizes.pop_back(); m_scopeSizes.pop_back();
} }
std::size_t SanitizeVisitor::RegisterFunction(DeclareFunctionStatement* funcDecl) std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement* funcDecl)
{ {
if (auto* identifier = FindIdentifier(funcDecl->name)) std::size_t functionIndex = m_functions.size();
auto& funcData = m_functions.emplace_back();
funcData.node = funcDecl;
return functionIndex;
}
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
{
assert(funcIndex < m_functions.size());
auto& funcData = m_functions[funcIndex];
assert(funcData.defined);
funcData.flags |= flags;
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
PropagateFunctionFlags(i, funcData.flags, seen);
}
auto SanitizeVisitor::RegisterFunction(std::size_t functionIndex) -> FunctionData&
{
assert(m_functions.size() >= functionIndex);
auto& funcData = m_functions[functionIndex];
assert(!funcData.defined);
funcData.defined = true;
if (auto* identifier = FindIdentifier(funcData.node->name))
{ {
bool duplicate = true; bool duplicate = true;
// Functions cannot be declared twice, except for entry ones if their stages are different // Functions cannot be declared twice, except for entry ones if their stages are different
if (funcDecl->entryStage && identifier->type == Identifier::Type::Function) if (funcData.node->entryStage && identifier->type == Identifier::Type::Function)
{ {
auto& otherFunction = m_functions[identifier->index]; auto& otherFunction = m_functions[identifier->index];
if (funcDecl->entryStage != otherFunction->entryStage) if (funcData.node->entryStage != otherFunction.node->entryStage)
duplicate = false; duplicate = false;
} }
if (duplicate) if (duplicate)
throw AstError{ funcDecl->name + " is already used" }; throw AstError{ funcData.node->name + " is already used" };
} }
std::size_t functionIndex = m_functions.size();
m_functions.push_back(funcDecl);
m_identifiersInScope.push_back({ m_identifiersInScope.push_back({
funcDecl->name, funcData.node->name,
functionIndex, functionIndex,
Identifier::Type::Function Identifier::Type::Function
}); });
return functionIndex; return funcData;
} }
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type) std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
@ -965,7 +1056,7 @@ namespace Nz::ShaderAst
std::move(name), std::move(name),
structIndex, structIndex,
Identifier::Type::Struct Identifier::Type::Struct
}); });
return structIndex; return structIndex;
} }
@ -988,6 +1079,26 @@ namespace Nz::ShaderAst
return varIndex; return varIndex;
} }
void SanitizeVisitor::ResolveFunctions()
{
// Once every function is known, we can propagate flags
Bitset<> seen;
for (std::size_t funcIndex = 0; funcIndex < m_functions.size(); ++funcIndex)
{
auto& funcData = m_functions[funcIndex];
PropagateFunctionFlags(funcIndex, funcData.flags, seen);
seen.Clear();
}
for (const FunctionData& funcData : m_functions)
{
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage && *funcData.node->entryStage != ShaderStageType::Fragment)
throw AstError{ "discard can only be used in the fragment stage" };
}
}
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType) std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
{ {
return std::visit([&](auto&& arg) -> std::size_t return std::visit([&](auto&& arg) -> std::size_t

View File

@ -110,6 +110,7 @@ namespace Nz
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = { std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::FragCoord, { "gl_FragCoord", ShaderStageType::Fragment } }, { ShaderAst::BuiltinEntry::FragCoord, { "gl_FragCoord", ShaderStageType::Fragment } },
{ ShaderAst::BuiltinEntry::FragDepth, { "gl_FragDepth", ShaderStageType::Fragment } },
{ ShaderAst::BuiltinEntry::VertexPosition, { "gl_Position", ShaderStageType::Vertex } } { ShaderAst::BuiltinEntry::VertexPosition, { "gl_Position", ShaderStageType::Vertex } }
}; };
} }
@ -206,6 +207,10 @@ namespace Nz
Append("gl_FragCoord"); Append("gl_FragCoord");
break; break;
case ShaderAst::BuiltinEntry::FragDepth:
Append("gl_FragDepth");
break;
case ShaderAst::BuiltinEntry::VertexPosition: case ShaderAst::BuiltinEntry::VertexPosition:
Append("gl_Position"); Append("gl_Position");
break; break;
@ -500,6 +505,15 @@ namespace Nz
void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node) void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node)
{ {
if (node.entryStage == ShaderStageType::Fragment && node.earlyFragmentTests && *node.earlyFragmentTests)
{
if ((m_environment.glES && m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1) || (!m_environment.glES && m_environment.glMajorVersion >= 4 && m_environment.glMinorVersion >= 2) || m_environment.extCallback("GL_ARB_shader_image_load_store"))
{
AppendLine("layout(early_fragment_tests) in;");
AppendLine();
}
}
HandleInOut(); HandleInOut();
AppendLine("void main()"); AppendLine("void main()");
EnterScope(); EnterScope();

View File

@ -41,6 +41,20 @@ namespace Nz
inline bool HasValue() const { return builtin.has_value(); } inline bool HasValue() const { return builtin.has_value(); }
}; };
struct LangWriter::DepthWriteAttribute
{
std::optional<ShaderAst::DepthWriteMode> writeMode;
inline bool HasValue() const { return writeMode.has_value(); }
};
struct LangWriter::EarlyFragmentTestsAttribute
{
std::optional<bool> earlyFragmentTests;
inline bool HasValue() const { return earlyFragmentTests.has_value(); }
};
struct LangWriter::EntryAttribute struct LangWriter::EntryAttribute
{ {
std::optional<ShaderStageType> stageType; std::optional<ShaderStageType> stageType;
@ -216,12 +230,12 @@ namespace Nz
Append(" "); Append(" ");
} }
void LangWriter::AppendAttribute(BindingAttribute builtin) void LangWriter::AppendAttribute(BindingAttribute binding)
{ {
if (!builtin.HasValue()) if (!binding.HasValue())
return; return;
Append("binding(", *builtin.bindingIndex, ")"); Append("binding(", *binding.bindingIndex, ")");
} }
void LangWriter::AppendAttribute(BuiltinAttribute builtin) void LangWriter::AppendAttribute(BuiltinAttribute builtin)
@ -235,11 +249,51 @@ namespace Nz
Append("builtin(fragcoord)"); Append("builtin(fragcoord)");
break; break;
case ShaderAst::BuiltinEntry::FragDepth:
Append("builtin(fragdepth)");
break;
case ShaderAst::BuiltinEntry::VertexPosition: case ShaderAst::BuiltinEntry::VertexPosition:
Append("builtin(position)"); Append("builtin(position)");
break; break;
} }
} }
void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite)
{
if (!depthWrite.HasValue())
return;
switch (*depthWrite.writeMode)
{
case ShaderAst::DepthWriteMode::Greater:
Append("depth_write(greater)");
break;
case ShaderAst::DepthWriteMode::Less:
Append("depth_write(less)");
break;
case ShaderAst::DepthWriteMode::Replace:
Append("depth_write(replace)");
break;
case ShaderAst::DepthWriteMode::Unchanged:
Append("depth_write(unchanged)");
break;
}
}
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests)
{
if (!earlyFragmentTests.HasValue())
return;
if (*earlyFragmentTests.earlyFragmentTests)
Append("early_fragment_tests(on)");
else
Append("early_fragment_tests(off)");
}
void LangWriter::AppendAttribute(EntryAttribute entry) void LangWriter::AppendAttribute(EntryAttribute entry)
{ {
@ -553,7 +607,7 @@ namespace Nz
std::optional<std::size_t> varIndexOpt = node.varIndex; std::optional<std::size_t> varIndexOpt = node.varIndex;
AppendAttributes(true, EntryAttribute{ node.entryStage }); AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite });
Append("fn ", node.name, "("); Append("fn ", node.name, "(");
for (std::size_t i = 0; i < node.parameters.size(); ++i) for (std::size_t i = 0; i < node.parameters.size(); ++i)
{ {

View File

@ -11,7 +11,14 @@
namespace Nz::ShaderLang namespace Nz::ShaderLang
{ {
namespace namespace
{ {
std::unordered_map<std::string, ShaderAst::DepthWriteMode> s_depthWriteModes = {
{ "greater", ShaderAst::DepthWriteMode::Greater },
{ "less", ShaderAst::DepthWriteMode::Less },
{ "replace", ShaderAst::DepthWriteMode::Replace },
{ "unchanged", ShaderAst::DepthWriteMode::Unchanged },
};
std::unordered_map<std::string, ShaderAst::PrimitiveType> s_identifierToBasicType = { std::unordered_map<std::string, ShaderAst::PrimitiveType> s_identifierToBasicType = {
{ "bool", ShaderAst::PrimitiveType::Boolean }, { "bool", ShaderAst::PrimitiveType::Boolean },
{ "i32", ShaderAst::PrimitiveType::Int32 }, { "i32", ShaderAst::PrimitiveType::Int32 },
@ -20,12 +27,14 @@ namespace Nz::ShaderLang
}; };
std::unordered_map<std::string, ShaderAst::AttributeType> s_identifierToAttributeType = { std::unordered_map<std::string, ShaderAst::AttributeType> s_identifierToAttributeType = {
{ "binding", ShaderAst::AttributeType::Binding }, { "binding", ShaderAst::AttributeType::Binding },
{ "builtin", ShaderAst::AttributeType::Builtin }, { "builtin", ShaderAst::AttributeType::Builtin },
{ "entry", ShaderAst::AttributeType::Entry }, { "depth_write", ShaderAst::AttributeType::DepthWrite },
{ "layout", ShaderAst::AttributeType::Layout }, { "early_fragment_tests", ShaderAst::AttributeType::EarlyFragmentTests },
{ "location", ShaderAst::AttributeType::Location }, { "entry", ShaderAst::AttributeType::Entry },
{ "opt", ShaderAst::AttributeType::Option }, { "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
{ "opt", ShaderAst::AttributeType::Option },
}; };
std::unordered_map<std::string, ShaderStageType> s_entryPoints = { std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
@ -35,6 +44,7 @@ namespace Nz::ShaderLang
std::unordered_map<std::string, ShaderAst::BuiltinEntry> s_builtinMapping = { std::unordered_map<std::string, ShaderAst::BuiltinEntry> s_builtinMapping = {
{ "fragcoord", ShaderAst::BuiltinEntry::FragCoord }, { "fragcoord", ShaderAst::BuiltinEntry::FragCoord },
{ "fragdepth", ShaderAst::BuiltinEntry::FragDepth },
{ "position", ShaderAst::BuiltinEntry::VertexPosition } { "position", ShaderAst::BuiltinEntry::VertexPosition }
}; };
@ -483,11 +493,55 @@ namespace Nz::ShaderLang
for (const auto& [attributeType, arg] : attributes) for (const auto& [attributeType, arg] : attributes)
{ {
switch (attributeType) switch (attributeType)
{ {
case ShaderAst::AttributeType::DepthWrite:
{
if (func->depthWrite)
throw AttributeError{ "attribute depth_write can only be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AttributeError{ "attribute entry requires a string parameter" };
const std::string& argStr = std::get<std::string>(arg);
auto it = s_depthWriteModes.find(argStr);
if (it == s_depthWriteModes.end())
throw AttributeError{ ("invalid parameter " + argStr + " for depth_write attribute").c_str() };
func->depthWrite = it->second;
break;
}
case ShaderAst::AttributeType::EarlyFragmentTests:
{
if (func->earlyFragmentTests)
throw AttributeError{ "attribute early_fragment_tests can only be present once" };
if (std::holds_alternative<std::string>(arg))
{
const std::string& argStr = std::get<std::string>(arg);
if (argStr == "true" || argStr == "on")
func->earlyFragmentTests = true;
else if (argStr == "false" || argStr == "off")
func->earlyFragmentTests = false;
else
throw AttributeError{ "expected boolean value (got " + argStr + ")" };
}
else if (std::holds_alternative<std::monostate>(arg))
{
// No parameter, default to true
func->earlyFragmentTests = true;
}
else
throw AttributeError{ "unexpected value for early_fragment_tests" };
break;
}
case ShaderAst::AttributeType::Entry: case ShaderAst::AttributeType::Entry:
{ {
if (func->entryStage) if (func->entryStage)
throw AttributeError{ "attribute entry must be present once" }; throw AttributeError{ "attribute entry can only be present once" };
if (!std::holds_alternative<std::string>(arg)) if (!std::holds_alternative<std::string>(arg))
throw AttributeError{ "attribute entry requires a string parameter" }; throw AttributeError{ "attribute entry requires a string parameter" };

View File

@ -37,7 +37,8 @@ namespace Nz
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = { std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } }, { ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } },
{ ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } } { ShaderAst::BuiltinEntry::FragDepth, { "FragmentDepth", ShaderStageType::Fragment, SpirvBuiltIn::FragDepth } },
{ ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
}; };
class PreVisitor : public ShaderAst::AstRecursiveVisitor class PreVisitor : public ShaderAst::AstRecursiveVisitor
@ -181,6 +182,28 @@ namespace Nz
{ {
using EntryPoint = SpirvAstVisitor::EntryPoint; using EntryPoint = SpirvAstVisitor::EntryPoint;
std::vector<SpirvExecutionMode> executionModes;
if (*entryPointType == ShaderStageType::Fragment)
{
executionModes.push_back(SpirvExecutionMode::OriginUpperLeft);
if (node.earlyFragmentTests && *node.earlyFragmentTests)
executionModes.push_back(SpirvExecutionMode::EarlyFragmentTests);
if (node.depthWrite)
{
executionModes.push_back(SpirvExecutionMode::DepthReplacing);
switch (*node.depthWrite)
{
case ShaderAst::DepthWriteMode::Replace: break;
case ShaderAst::DepthWriteMode::Greater: executionModes.push_back(SpirvExecutionMode::DepthGreater); break;
case ShaderAst::DepthWriteMode::Less: executionModes.push_back(SpirvExecutionMode::DepthLess); break;
case ShaderAst::DepthWriteMode::Unchanged: executionModes.push_back(SpirvExecutionMode::DepthUnchanged); break;
}
}
}
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{})); funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {})); funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {}));
@ -248,7 +271,8 @@ namespace Nz
inputStruct, inputStruct,
outputStructId, outputStructId,
std::move(inputs), std::move(inputs),
std::move(outputs) std::move(outputs),
std::move(executionModes)
}; };
} }
@ -522,7 +546,6 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450); m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
std::optional<UInt32> fragmentFuncId;
for (auto& func : m_currentState->funcs) for (auto& func : m_currentState->funcs)
{ {
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name); m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
@ -559,15 +582,18 @@ namespace Nz
for (const auto& output : entryPointData.outputs) for (const auto& output : entryPointData.outputs)
appender(output.varId); appender(output.varId);
}); });
if (entryPointData.stageType == ShaderStageType::Fragment)
fragmentFuncId = func.funcId;
} }
} }
if (fragmentFuncId) // Write execution modes
m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft); for (auto& func : m_currentState->funcs)
{
if (func.entryPointData)
{
for (SpirvExecutionMode executionMode : func.entryPointData->executionModes)
m_currentState->header.Append(SpirvOp::OpExecutionMode, func.funcId, executionMode);
}
}
} }
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)