Shader: Attribute can now have expressions as values and struct fields can be conditionally supported

This commit is contained in:
Jérôme Leclercq
2021-07-07 11:41:58 +02:00
parent 749b40cb31
commit f9af35b489
36 changed files with 945 additions and 600 deletions

View File

@@ -9,6 +9,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <Nazara/Shader/Ast/AstExpressionVisitor.hpp>
#include <Nazara/Shader/Ast/AstStatementVisitor.hpp>
#include <vector>
@@ -30,6 +31,7 @@ namespace Nz::ShaderAst
AstCloner& operator=(AstCloner&&) = delete;
protected:
template<typename T> AttributeValue<T> CloneAttribute(const AttributeValue<T>& attribute);
inline ExpressionPtr CloneExpression(const ExpressionPtr& expr);
inline StatementPtr CloneStatement(const StatementPtr& statement);
@@ -44,6 +46,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(CallMethodExpression& node);
virtual ExpressionPtr Clone(CastExpression& node);
virtual ExpressionPtr Clone(ConditionalExpression& node);
virtual ExpressionPtr Clone(ConstantIndexExpression& node);
virtual ExpressionPtr Clone(ConstantExpression& node);
virtual ExpressionPtr Clone(IdentifierExpression& node);
virtual ExpressionPtr Clone(IntrinsicExpression& node);

View File

@@ -7,6 +7,21 @@
namespace Nz::ShaderAst
{
template<typename T>
AttributeValue<T> AstCloner::CloneAttribute(const AttributeValue<T>& attribute)
{
if (!attribute.HasValue())
return {};
if (attribute.IsExpression())
return CloneExpression(attribute.GetExpression());
else
{
assert(attribute.IsResultingValue());
return attribute.GetResultingValue();
}
}
ExpressionPtr AstCloner::CloneExpression(const ExpressionPtr& expr)
{
if (!expr)

View File

@@ -35,6 +35,7 @@ NAZARA_SHADERAST_EXPRESSION(CallMethodExpression)
NAZARA_SHADERAST_EXPRESSION(CastExpression)
NAZARA_SHADERAST_EXPRESSION(ConditionalExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantIndexExpression)
NAZARA_SHADERAST_EXPRESSION(IdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(SelectOptionExpression)

View File

@@ -18,21 +18,32 @@ namespace Nz::ShaderAst
class NAZARA_SHADER_API AstOptimizer : public AstCloner
{
public:
struct Options;
AstOptimizer() = default;
AstOptimizer(const AstOptimizer&) = delete;
AstOptimizer(AstOptimizer&&) = delete;
~AstOptimizer() = default;
StatementPtr Optimise(Statement& statement);
StatementPtr Optimise(Statement& statement, UInt64 enabledConditions);
inline ExpressionPtr Optimise(Expression& expression);
inline ExpressionPtr Optimise(Expression& expression, const Options& options);
inline StatementPtr Optimise(Statement& statement);
inline StatementPtr Optimise(Statement& statement, const Options& options);
AstOptimizer& operator=(const AstOptimizer&) = delete;
AstOptimizer& operator=(AstOptimizer&&) = delete;
struct Options
{
std::function<const ConstantValue&(std::size_t constantId)> constantQueryCallback;
std::optional<UInt64> enabledOptions = 0;
};
protected:
ExpressionPtr Clone(BinaryExpression& node) override;
ExpressionPtr Clone(CastExpression& node) override;
ExpressionPtr Clone(ConditionalExpression& node) override;
ExpressionPtr Clone(ConstantIndexExpression& node) override;
ExpressionPtr Clone(UnaryExpression& node) override;
StatementPtr Clone(BranchStatement& node) override;
StatementPtr Clone(ConditionalStatement& node) override;
@@ -45,11 +56,13 @@ namespace Nz::ShaderAst
template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4);
private:
std::optional<UInt64> m_enabledOptions;
Options m_options;
};
inline ExpressionPtr Optimize(Expression& expr);
inline ExpressionPtr Optimize(Expression& expr, const AstOptimizer::Options& options);
inline StatementPtr Optimize(Statement& ast);
inline StatementPtr Optimize(Statement& ast, UInt64 enabledConditions);
inline StatementPtr Optimize(Statement& ast, const AstOptimizer::Options& options);
}
#include <Nazara/Shader/Ast/AstOptimizer.inl>

View File

@@ -7,16 +7,52 @@
namespace Nz::ShaderAst
{
inline ExpressionPtr AstOptimizer::Optimise(Expression& expression)
{
m_options = {};
return CloneExpression(expression);
}
inline ExpressionPtr AstOptimizer::Optimise(Expression& expression, const Options& options)
{
m_options = options;
return CloneExpression(expression);
}
inline StatementPtr AstOptimizer::Optimise(Statement& statement)
{
m_options = {};
return CloneStatement(statement);
}
inline StatementPtr AstOptimizer::Optimise(Statement& statement, const Options& options)
{
m_options = options;
return CloneStatement(statement);
}
inline ExpressionPtr Optimize(Expression& ast)
{
AstOptimizer optimize;
return optimize.Optimise(ast);
}
inline ExpressionPtr Optimize(Expression& ast, const AstOptimizer::Options& options)
{
AstOptimizer optimize;
return optimize.Optimise(ast, options);
}
inline StatementPtr Optimize(Statement& ast)
{
AstOptimizer optimize;
return optimize.Optimise(ast);
}
inline StatementPtr Optimize(Statement& ast, UInt64 enabledConditions)
inline StatementPtr Optimize(Statement& ast, const AstOptimizer::Options& options)
{
AstOptimizer optimize;
return optimize.Optimise(ast, enabledConditions);
return optimize.Optimise(ast, options);
}
}

View File

@@ -29,6 +29,7 @@ namespace Nz::ShaderAst
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(ConstantIndexExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SelectOptionExpression& node) override;

View File

@@ -30,6 +30,7 @@ namespace Nz::ShaderAst
void Serialize(CallFunctionExpression& node);
void Serialize(CallMethodExpression& node);
void Serialize(CastExpression& node);
void Serialize(ConstantIndexExpression& node);
void Serialize(ConditionalExpression& node);
void Serialize(ConstantExpression& node);
void Serialize(IdentifierExpression& node);
@@ -53,6 +54,7 @@ namespace Nz::ShaderAst
void Serialize(ReturnStatement& node);
protected:
template<typename T> void Attribute(AttributeValue<T>& attribute);
template<typename T> void Container(T& container);
template<typename T> void Enum(T& enumVal);
template<typename T> void OptEnum(std::optional<T>& optVal);

View File

@@ -7,6 +7,73 @@
namespace Nz::ShaderAst
{
template<typename T>
void AstSerializerBase::Attribute(AttributeValue<T>& attribute)
{
UInt32 valueType;
if (IsWriting())
{
if (!attribute.HasValue())
valueType = 0;
else if (attribute.IsExpression())
valueType = 1;
else if (attribute.IsResultingValue())
valueType = 2;
else
throw std::runtime_error("unexpected attribute");
}
Value(valueType);
switch (valueType)
{
case 0:
if (!IsWriting())
attribute = {};
break;
case 1:
{
if (!IsWriting())
{
ExpressionPtr expr;
Node(expr);
attribute = std::move(expr);
}
else
Node(const_cast<ExpressionPtr&>(attribute.GetExpression())); //< not used for writing
break;
}
case 2:
{
if (!IsWriting())
{
T value;
if constexpr (std::is_enum_v<T>)
Enum(value);
else
Value(value);
attribute = std::move(value);
}
else
{
T& value = const_cast<T&>(attribute.GetResultingValue()); //< not used for writing
if constexpr (std::is_enum_v<T>)
Enum(value);
else
Value(value);
}
break;
}
}
}
template<typename T>
void AstSerializerBase::Container(T& container)
{

View File

@@ -40,6 +40,7 @@ namespace Nz::ShaderAst
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(ConstantIndexExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SelectOptionExpression& node) override;

View File

@@ -9,17 +9,52 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Ast/Enums.hpp>
#include <memory>
#include <optional>
#include <variant>
namespace Nz::ShaderAst
{
struct Expression;
using ExpressionPtr = std::unique_ptr<Expression>;
template<typename T>
class AttributeValue
{
public:
AttributeValue() = default;
AttributeValue(T value);
AttributeValue(ExpressionPtr expr);
AttributeValue(const AttributeValue&) = default;
AttributeValue(AttributeValue&&) = default;
~AttributeValue() = default;
ExpressionPtr&& GetExpression() &&;
const ExpressionPtr& GetExpression() const &;
const T& GetResultingValue() const;
bool IsExpression() const;
bool IsResultingValue() const;
bool HasValue() const;
AttributeValue& operator=(const AttributeValue&) = default;
AttributeValue& operator=(AttributeValue&&) = default;
private:
std::variant<std::monostate, T, ExpressionPtr> m_value;
};
struct Attribute
{
using Param = std::variant<std::monostate, long long, std::string>;
using Param = std::optional<ExpressionPtr>;
AttributeType type;
Param args;
};
}
#include <Nazara/Shader/Ast/Attribute.inl>
#endif

View File

@@ -0,0 +1,72 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <cassert>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
template<typename T>
AttributeValue<T>::AttributeValue(T value) :
m_value(std::move(value))
{
}
template<typename T>
AttributeValue<T>::AttributeValue(ExpressionPtr expr)
{
assert(expr);
m_value = std::move(expr);
}
template<typename T>
ExpressionPtr&& AttributeValue<T>::GetExpression() &&
{
if (!IsExpression())
throw std::runtime_error("excepted expression");
return std::get<ExpressionPtr>(std::move(m_value));
}
template<typename T>
const ExpressionPtr& AttributeValue<T>::GetExpression() const &
{
if (!IsExpression())
throw std::runtime_error("excepted expression");
assert(std::get<ExpressionPtr>(m_value));
return std::get<ExpressionPtr>(m_value);
}
template<typename T>
const T& AttributeValue<T>::GetResultingValue() const
{
if (!IsResultingValue())
throw std::runtime_error("excepted resulting value");
return std::get<T>(m_value);
}
template<typename T>
bool AttributeValue<T>::IsExpression() const
{
return std::holds_alternative<ExpressionPtr>(m_value);
}
template<typename T>
bool AttributeValue<T>::IsResultingValue() const
{
return std::holds_alternative<T>(m_value);
}
template<typename T>
bool AttributeValue<T>::HasValue() const
{
return !std::holds_alternative<std::monostate>(m_value);
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@@ -8,6 +8,7 @@
#define NAZARA_SHADER_CONSTANTVALUE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Core/TypeList.hpp>
#include <Nazara/Math/Vector2.hpp>
#include <Nazara/Math/Vector3.hpp>
#include <Nazara/Math/Vector4.hpp>
@@ -17,7 +18,7 @@
namespace Nz::ShaderAst
{
using ConstantValue = std::variant<
using ConstantTypes = TypeList<
bool,
float,
Int32,
@@ -30,6 +31,8 @@ namespace Nz::ShaderAst
Vector4i32
>;
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant);
}

View File

@@ -23,12 +23,12 @@ namespace Nz
{
Binding, //< Binding (external var only) - has argument index
Builtin, //< Builtin (struct member only) - has argument type
Cond, //< Conditional compilation option - has argument expr
DepthWrite, //< Depth write mode (function only) - has argument type
EarlyFragmentTests, //< Entry point (function only) - has argument on/off
Entry, //< Entry point (function only) - has argument type
Layout, //< Struct layout (struct only) - has argument style
Location, //< Location (struct member only) - has argument index
Option, //< Conditional compilation option - has argument expr
Set, //< Binding set (external var only) - has argument index
};

View File

@@ -82,13 +82,14 @@ namespace Nz::ShaderAst
{
struct StructMember
{
std::optional<BuiltinEntry> builtin;
std::optional<UInt32> locationIndex;
AttributeValue<BuiltinEntry> builtin;
AttributeValue<bool> cond;
AttributeValue<UInt32> locationIndex;
std::string name;
ExpressionType type;
};
std::optional<StructLayout> layout;
AttributeValue<StructLayout> layout;
std::string name;
std::vector<StructMember> members;
};

View File

@@ -64,7 +64,7 @@ namespace Nz::ShaderAst
std::optional<ExpressionType> cachedExpressionType;
};
struct NAZARA_SHADER_API AccessIdentifierExpression : public Expression
struct NAZARA_SHADER_API AccessIdentifierExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -73,7 +73,7 @@ namespace Nz::ShaderAst
std::vector<std::string> identifiers;
};
struct NAZARA_SHADER_API AccessIndexExpression : public Expression
struct NAZARA_SHADER_API AccessIndexExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -82,7 +82,7 @@ namespace Nz::ShaderAst
std::vector<ExpressionPtr> indices;
};
struct NAZARA_SHADER_API AssignExpression : public Expression
struct NAZARA_SHADER_API AssignExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -92,7 +92,7 @@ namespace Nz::ShaderAst
ExpressionPtr right;
};
struct NAZARA_SHADER_API BinaryExpression : public Expression
struct NAZARA_SHADER_API BinaryExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -102,7 +102,7 @@ namespace Nz::ShaderAst
ExpressionPtr right;
};
struct NAZARA_SHADER_API CallFunctionExpression : public Expression
struct NAZARA_SHADER_API CallFunctionExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -111,7 +111,7 @@ namespace Nz::ShaderAst
std::vector<ExpressionPtr> parameters;
};
struct NAZARA_SHADER_API CallMethodExpression : public Expression
struct NAZARA_SHADER_API CallMethodExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -121,7 +121,7 @@ namespace Nz::ShaderAst
std::vector<ExpressionPtr> parameters;
};
struct NAZARA_SHADER_API CastExpression : public Expression
struct NAZARA_SHADER_API CastExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -130,17 +130,17 @@ namespace Nz::ShaderAst
std::array<ExpressionPtr, 4> expressions;
};
struct NAZARA_SHADER_API ConditionalExpression : public Expression
struct NAZARA_SHADER_API ConditionalExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t optionIndex;
ExpressionPtr condition;
ExpressionPtr falsePath;
ExpressionPtr truePath;
};
struct NAZARA_SHADER_API ConstantExpression : public Expression
struct NAZARA_SHADER_API ConstantExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -148,7 +148,15 @@ namespace Nz::ShaderAst
ShaderAst::ConstantValue value;
};
struct NAZARA_SHADER_API IdentifierExpression : public Expression
struct NAZARA_SHADER_API ConstantIndexExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t constantId;
};
struct NAZARA_SHADER_API IdentifierExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -156,7 +164,7 @@ namespace Nz::ShaderAst
std::string identifier;
};
struct NAZARA_SHADER_API IntrinsicExpression : public Expression
struct NAZARA_SHADER_API IntrinsicExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -165,7 +173,7 @@ namespace Nz::ShaderAst
std::vector<ExpressionPtr> parameters;
};
struct NAZARA_SHADER_API SelectOptionExpression : public Expression
struct NAZARA_SHADER_API SelectOptionExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -175,7 +183,7 @@ namespace Nz::ShaderAst
ExpressionPtr truePath;
};
struct NAZARA_SHADER_API SwizzleExpression : public Expression
struct NAZARA_SHADER_API SwizzleExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -193,7 +201,7 @@ namespace Nz::ShaderAst
std::size_t variableId;
};
struct NAZARA_SHADER_API UnaryExpression : public Expression
struct NAZARA_SHADER_API UnaryExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@@ -221,7 +229,7 @@ namespace Nz::ShaderAst
Statement& operator=(Statement&&) noexcept = default;
};
struct NAZARA_SHADER_API BranchStatement : public Statement
struct NAZARA_SHADER_API BranchStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
@@ -241,7 +249,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::size_t optionIndex;
ExpressionPtr condition;
StatementPtr statement;
};
@@ -252,12 +260,13 @@ namespace Nz::ShaderAst
struct ExternalVar
{
std::optional<UInt32> bindingIndex;
std::optional<UInt32> bindingSet;
AttributeValue<UInt32> bindingIndex;
AttributeValue<UInt32> bindingSet;
std::string name;
ExpressionType type;
};
AttributeValue<UInt32> bindingSet;
std::optional<std::size_t> varIndex;
std::vector<ExternalVar> externalVars;
};
@@ -273,12 +282,11 @@ namespace Nz::ShaderAst
ExpressionType type;
};
std::optional<DepthWriteMode> depthWrite;
std::optional<bool> earlyFragmentTests;
std::optional<ShaderStageType> entryStage;
AttributeValue<DepthWriteMode> depthWrite;
AttributeValue<bool> earlyFragmentTests;
AttributeValue<ShaderStageType> entryStage;
std::optional<std::size_t> funcIndex;
std::optional<std::size_t> varIndex;
std::string optionName;
std::string name;
std::vector<Parameter> parameters;
std::vector<StatementPtr> statements;

View File

@@ -36,6 +36,7 @@ namespace Nz::ShaderAst
struct Options
{
std::unordered_set<std::string> reservedIdentifiers;
UInt64 enabledOptions = 0;
bool makeVariableNameUnique = false;
bool removeOptionDeclaration = true;
};
@@ -71,7 +72,7 @@ namespace Nz::ShaderAst
StatementPtr Clone(ExpressionStatement& node) override;
StatementPtr Clone(MultiStatement& node) override;
inline const Identifier* FindIdentifier(const std::string_view& identifierName) const;
const Identifier* FindIdentifier(const std::string_view& identifierName) const;
Expression& MandatoryExpr(ExpressionPtr& node);
Statement& MandatoryStatement(StatementPtr& node);
@@ -81,14 +82,17 @@ namespace Nz::ShaderAst
void PushScope();
void PopScope();
template<typename T> const T& ComputeAttributeValue(AttributeValue<T>& attribute);
ConstantValue ComputeConstantValue(Expression& expr);
std::size_t DeclareFunction(DeclareFunctionStatement& funcDecl);
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
std::size_t RegisterConstant(std::string name, ConstantValue value);
FunctionData& RegisterFunction(std::size_t functionIndex);
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterOption(std::string name, ExpressionType type);
std::size_t RegisterStruct(std::string name, StructDescription description);
std::size_t RegisterStruct(std::string name, StructDescription* description);
std::size_t RegisterVariable(std::string name, ExpressionType type);
void ResolveFunctions();
@@ -118,9 +122,9 @@ namespace Nz::ShaderAst
enum class Type
{
Alias,
Constant,
Function,
Intrinsic,
Option,
Struct,
Variable
};
@@ -130,14 +134,6 @@ namespace Nz::ShaderAst
Type type;
};
std::vector<Identifier> m_identifiersInScope;
std::vector<FunctionData> m_functions;
std::vector<IntrinsicType> m_intrinsics;
std::vector<ExpressionType> m_options;
std::vector<StructDescription> m_structs;
std::vector<ExpressionType> m_variableTypes;
std::vector<std::size_t> m_scopeSizes;
struct Context;
Context* m_context;
};

View File

@@ -12,15 +12,6 @@ namespace Nz::ShaderAst
return Sanitize(statement, {}, error);
}
inline auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
{
auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
if (it == m_identifiersInScope.rend())
return nullptr;
return &*it;
}
inline StatementPtr Sanitize(Statement& ast, std::string* error)
{
SanitizeVisitor sanitizer;