Shader: Handle type as expressions

This commit is contained in:
Jérôme Leclercq
2022-02-08 17:03:34 +01:00
parent 5ce8120a0c
commit 402e16bd2b
53 changed files with 1746 additions and 1141 deletions

View File

@@ -24,7 +24,8 @@ namespace Nz::ShaderAst
AstCloner(AstCloner&&) = delete;
~AstCloner() = default;
template<typename T> AttributeValue<T> Clone(const AttributeValue<T>& attribute);
template<typename T> ExpressionValue<T> Clone(const ExpressionValue<T>& expressionValue);
inline ExpressionValue<ExpressionType> Clone(const ExpressionValue<ExpressionType>& expressionValue);
ExpressionPtr Clone(Expression& statement);
StatementPtr Clone(Statement& statement);
@@ -37,6 +38,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr CloneExpression(Expression& expr);
virtual StatementPtr CloneStatement(Statement& statement);
virtual ExpressionValue<ExpressionType> CloneType(const ExpressionValue<ExpressionType>& exprType);
virtual ExpressionPtr Clone(AccessIdentifierExpression& node);
virtual ExpressionPtr Clone(AccessIndexExpression& node);
@@ -69,6 +71,7 @@ namespace Nz::ShaderAst
virtual StatementPtr Clone(MultiStatement& node);
virtual StatementPtr Clone(NoOpStatement& node);
virtual StatementPtr Clone(ReturnStatement& node);
virtual StatementPtr Clone(ScopedStatement& node);
virtual StatementPtr Clone(WhileStatement& node);
#define NAZARA_SHADERAST_NODE(NodeType) void Visit(NodeType& node) override;
@@ -85,7 +88,7 @@ namespace Nz::ShaderAst
std::vector<StatementPtr> m_statementStack;
};
template<typename T> AttributeValue<T> Clone(const AttributeValue<T>& attribute);
template<typename T> ExpressionValue<T> Clone(const ExpressionValue<T>& attribute);
inline ExpressionPtr Clone(Expression& node);
inline StatementPtr Clone(Statement& node);
}

View File

@@ -8,20 +8,25 @@
namespace Nz::ShaderAst
{
template<typename T>
AttributeValue<T> AstCloner::Clone(const AttributeValue<T>& attribute)
ExpressionValue<T> AstCloner::Clone(const ExpressionValue<T>& expressionValue)
{
if (!attribute.HasValue())
if (!expressionValue.HasValue())
return {};
if (attribute.IsExpression())
return CloneExpression(attribute.GetExpression());
if (expressionValue.IsExpression())
return CloneExpression(expressionValue.GetExpression());
else
{
assert(attribute.IsResultingValue());
return attribute.GetResultingValue();
assert(expressionValue.IsResultingValue());
return expressionValue.GetResultingValue();
}
}
inline ExpressionValue<ExpressionType> AstCloner::Clone(const ExpressionValue<ExpressionType>& expressionValue)
{
return CloneType(expressionValue);
}
ExpressionPtr AstCloner::CloneExpression(const ExpressionPtr& expr)
{
if (!expr)
@@ -40,7 +45,7 @@ namespace Nz::ShaderAst
template<typename T>
AttributeValue<T> Clone(const AttributeValue<T>& attribute)
ExpressionValue<T> Clone(const ExpressionValue<T>& attribute)
{
AstCloner cloner;
return cloner.Clone(attribute);

View File

@@ -21,7 +21,7 @@ namespace Nz::ShaderAst
template<typename T> bool Compare(const T& lhs, const T& rhs);
template<typename T, std::size_t S> bool Compare(const std::array<T, S>& lhs, const std::array<T, S>& rhs);
template<typename T> bool Compare(const std::vector<T>& lhs, const std::vector<T>& rhs);
template<typename T> bool Compare(const AttributeValue<T>& lhs, const AttributeValue<T>& rhs);
template<typename T> bool Compare(const ExpressionValue<T>& lhs, const ExpressionValue<T>& rhs);
inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs);
inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs);
inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs);
@@ -59,6 +59,7 @@ namespace Nz::ShaderAst
inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs);
inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs);
inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs);
inline bool Compare(const ScopedStatement& lhs, const ScopedStatement& rhs);
inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs);
}

View File

@@ -78,7 +78,7 @@ namespace Nz::ShaderAst
}
template<typename T>
bool Compare(const AttributeValue<T>& lhs, const AttributeValue<T>& rhs)
bool Compare(const ExpressionValue<T>& lhs, const ExpressionValue<T>& rhs)
{
if (!Compare(lhs.HasValue(), rhs.HasValue()))
return false;
@@ -519,6 +519,14 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const ScopedStatement& lhs, const ScopedStatement& rhs)
{
if (!Compare(lhs.statement, rhs.statement))
return false;
return true;
}
inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs)
{
if (!Compare(lhs.unroll, rhs.unroll))

View File

@@ -13,7 +13,7 @@
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API ExpressionVisitorExcept : public AstExpressionVisitor
class NAZARA_SHADER_API AstExpressionVisitorExcept : public AstExpressionVisitor
{
public:
using AstExpressionVisitor::Visit;

View File

@@ -58,6 +58,7 @@ NAZARA_SHADERAST_STATEMENT(ExpressionStatement)
NAZARA_SHADERAST_STATEMENT(MultiStatement)
NAZARA_SHADERAST_STATEMENT(NoOpStatement)
NAZARA_SHADERAST_STATEMENT(ReturnStatement)
NAZARA_SHADERAST_STATEMENT(ScopedStatement)
NAZARA_SHADERAST_STATEMENT_LAST(WhileStatement)
#undef NAZARA_SHADERAST_EXPRESSION

View File

@@ -57,6 +57,8 @@ namespace Nz::ShaderAst
template<typename TargetType> ExpressionPtr PropagateVec3Cast(TargetType v1, TargetType v2, TargetType v3);
template<typename TargetType> ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4);
StatementPtr Unscope(StatementPtr node);
private:
Options m_options;
};

View File

@@ -51,6 +51,7 @@ namespace Nz::ShaderAst
void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override;
void Visit(ReturnStatement& node) override;
void Visit(ScopedStatement& node) override;
void Visit(WhileStatement& node) override;
};
}

View File

@@ -32,7 +32,7 @@ namespace Nz::ShaderAst
struct Callbacks
{
std::function<void(ShaderStageType stageType, const std::string& functionName)> onEntryPointDeclaration;
std::function<void(const std::string& optionName, const ExpressionType& optionType)> onOptionDeclaration;
std::function<void(const std::string& optionName, const ExpressionValue<ExpressionType>& optionType)> onOptionDeclaration;
};
private:

View File

@@ -54,12 +54,13 @@ namespace Nz::ShaderAst
void Serialize(MultiStatement& node);
void Serialize(NoOpStatement& node);
void Serialize(ReturnStatement& node);
void Serialize(ScopedStatement& node);
void Serialize(WhileStatement& node);
protected:
template<typename T> void Attribute(AttributeValue<T>& attribute);
template<typename T> void Container(T& container);
template<typename T> void Enum(T& enumVal);
template<typename T> void ExprValue(ExpressionValue<T>& attribute);
template<typename T> void OptEnum(std::optional<T>& optVal);
template<typename T> void OptVal(std::optional<T>& optVal);

View File

@@ -3,12 +3,42 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/AstSerializer.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
template<typename T>
void AstSerializerBase::Attribute(AttributeValue<T>& attribute)
void AstSerializerBase::Container(T& container)
{
bool isWriting = IsWriting();
UInt32 size;
if (isWriting)
size = SafeCast<UInt32>(container.size());
Value(size);
if (!isWriting)
container.resize(size);
}
template<typename T>
void AstSerializerBase::Enum(T& enumVal)
{
bool isWriting = IsWriting();
UInt32 value;
if (isWriting)
value = SafeCast<UInt32>(enumVal);
Value(value);
if (!isWriting)
enumVal = static_cast<T>(value);
}
template<typename T>
void AstSerializerBase::ExprValue(ExpressionValue<T>& attribute)
{
UInt32 valueType;
if (IsWriting())
@@ -55,6 +85,8 @@ namespace Nz::ShaderAst
T value;
if constexpr (std::is_enum_v<T>)
Enum(value);
else if constexpr (std::is_same_v<T, ExpressionType>)
Type(value);
else
Value(value);
@@ -65,6 +97,8 @@ namespace Nz::ShaderAst
T& value = const_cast<T&>(attribute.GetResultingValue()); //< not used for writing
if constexpr (std::is_enum_v<T>)
Enum(value);
else if constexpr (std::is_same_v<T, ExpressionType>)
Type(value);
else
Value(value);
}
@@ -74,35 +108,6 @@ namespace Nz::ShaderAst
}
}
template<typename T>
void AstSerializerBase::Container(T& container)
{
bool isWriting = IsWriting();
UInt32 size;
if (isWriting)
size = UInt32(container.size());
Value(size);
if (!isWriting)
container.resize(size);
}
template<typename T>
void AstSerializerBase::Enum(T& enumVal)
{
bool isWriting = IsWriting();
UInt32 value;
if (isWriting)
value = static_cast<UInt32>(enumVal);
Value(value);
if (!isWriting)
enumVal = static_cast<T>(value);
}
template<typename T>
void AstSerializerBase::OptEnum(std::optional<T>& optVal)
{
@@ -150,12 +155,12 @@ namespace Nz::ShaderAst
UInt32 fixedVal;
if (isWriting)
fixedVal = static_cast<UInt32>(val);
fixedVal = SafeCast<UInt32>(val);
Value(fixedVal);
if (!isWriting)
val = static_cast<std::size_t>(fixedVal);
val = SafeCast<std::size_t>(fixedVal);
}
inline ShaderAstSerializer::ShaderAstSerializer(ByteStream& stream) :

View File

@@ -13,7 +13,7 @@
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API StatementVisitorExcept : public AstStatementVisitor
class NAZARA_SHADER_API AstStatementVisitorExcept : public AstStatementVisitor
{
public:
using AstStatementVisitor::Visit;

View File

@@ -0,0 +1,36 @@
// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Engine - Shader module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_AST_ASTTYPES_HPP
#define NAZARA_SHADER_AST_ASTTYPES_HPP
#include <Nazara/Shader/Ast/ConstantValue.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <functional>
namespace Nz::ShaderAst
{
enum class TypeParameterCategory
{
ConstantValue,
FullType,
PrimitiveType,
StructType
};
struct PartialType;
using TypeParameter = std::variant<ConstantValue, ExpressionType, PartialType>;
struct PartialType
{
std::vector<TypeParameterCategory> parameters;
std::function<ExpressionType(const TypeParameter* parameters, std::size_t parameterCount)> buildFunc;
};
}
#endif // NAZARA_SHADER_AST_ASTTYPES_HPP

View File

@@ -20,15 +20,15 @@ namespace Nz::ShaderAst
using ExpressionPtr = std::unique_ptr<Expression>;
template<typename T>
class AttributeValue
class ExpressionValue
{
public:
AttributeValue() = default;
AttributeValue(T value);
AttributeValue(ExpressionPtr expr);
AttributeValue(const AttributeValue&) = default;
AttributeValue(AttributeValue&&) = default;
~AttributeValue() = default;
ExpressionValue() = default;
ExpressionValue(T value);
ExpressionValue(ExpressionPtr expr);
ExpressionValue(const ExpressionValue&) = default;
ExpressionValue(ExpressionValue&&) noexcept = default;
~ExpressionValue() = default;
ExpressionPtr&& GetExpression() &&;
const ExpressionPtr& GetExpression() const &;
@@ -39,14 +39,14 @@ namespace Nz::ShaderAst
bool HasValue() const;
AttributeValue& operator=(const AttributeValue&) = default;
AttributeValue& operator=(AttributeValue&&) = default;
ExpressionValue& operator=(const ExpressionValue&) = default;
ExpressionValue& operator=(ExpressionValue&&) noexcept = default;
private:
std::variant<std::monostate, T, ExpressionPtr> m_value;
};
struct Attribute
struct ExprValue
{
using Param = std::optional<ExpressionPtr>;

View File

@@ -10,20 +10,20 @@
namespace Nz::ShaderAst
{
template<typename T>
AttributeValue<T>::AttributeValue(T value) :
ExpressionValue<T>::ExpressionValue(T value) :
m_value(std::move(value))
{
}
template<typename T>
AttributeValue<T>::AttributeValue(ExpressionPtr expr)
ExpressionValue<T>::ExpressionValue(ExpressionPtr expr)
{
assert(expr);
m_value = std::move(expr);
}
template<typename T>
ExpressionPtr&& AttributeValue<T>::GetExpression() &&
ExpressionPtr&& ExpressionValue<T>::GetExpression() &&
{
if (!IsExpression())
throw std::runtime_error("excepted expression");
@@ -32,7 +32,7 @@ namespace Nz::ShaderAst
}
template<typename T>
const ExpressionPtr& AttributeValue<T>::GetExpression() const &
const ExpressionPtr& ExpressionValue<T>::GetExpression() const &
{
if (!IsExpression())
throw std::runtime_error("excepted expression");
@@ -42,7 +42,7 @@ namespace Nz::ShaderAst
}
template<typename T>
const T& AttributeValue<T>::GetResultingValue() const
const T& ExpressionValue<T>::GetResultingValue() const
{
if (!IsResultingValue())
throw std::runtime_error("excepted resulting value");
@@ -51,19 +51,19 @@ namespace Nz::ShaderAst
}
template<typename T>
bool AttributeValue<T>::IsExpression() const
bool ExpressionValue<T>::IsExpression() const
{
return std::holds_alternative<ExpressionPtr>(m_value);
}
template<typename T>
bool AttributeValue<T>::IsResultingValue() const
bool ExpressionValue<T>::IsResultingValue() const
{
return std::holds_alternative<T>(m_value);
}
template<typename T>
bool AttributeValue<T>::HasValue() const
bool ExpressionValue<T>::HasValue() const
{
return !std::holds_alternative<std::monostate>(m_value);
}

View File

@@ -29,14 +29,22 @@ namespace Nz::ShaderAst
ArrayType& operator=(const ArrayType& array);
ArrayType& operator=(ArrayType&&) noexcept = default;
AttributeValue<UInt32> length;
UInt32 length;
std::unique_ptr<ContainedType> containedType;
bool operator==(const ArrayType& rhs) const;
inline bool operator!=(const ArrayType& rhs) const;
};
struct IdentifierType //< Alias or struct
struct FunctionType
{
std::size_t funcIndex;
inline bool operator==(const FunctionType& rhs) const;
inline bool operator!=(const FunctionType& rhs) const;
};
struct IdentifierType
{
std::string name;
@@ -44,6 +52,14 @@ namespace Nz::ShaderAst
inline bool operator!=(const IdentifierType& rhs) const;
};
struct IntrinsicFunctionType
{
IntrinsicType intrinsic;
inline bool operator==(const IntrinsicFunctionType& rhs) const;
inline bool operator!=(const IntrinsicFunctionType& rhs) const;
};
struct MatrixType
{
std::size_t columnCount;
@@ -54,6 +70,22 @@ namespace Nz::ShaderAst
inline bool operator!=(const MatrixType& rhs) const;
};
struct NAZARA_SHADER_API MethodType
{
MethodType() = default;
MethodType(const MethodType& methodType);
MethodType(MethodType&&) noexcept = default;
MethodType& operator=(const MethodType& methodType);
MethodType& operator=(MethodType&&) noexcept = default;
std::unique_ptr<ContainedType> objectType;
std::size_t methodIndex;
bool operator==(const MethodType& rhs) const;
inline bool operator!=(const MethodType& rhs) const;
};
struct NoType
{
inline bool operator==(const NoType& rhs) const;
@@ -77,9 +109,17 @@ namespace Nz::ShaderAst
inline bool operator!=(const StructType& rhs) const;
};
struct Type
{
std::size_t typeIndex;
inline bool operator==(const Type& rhs) const;
inline bool operator!=(const Type& rhs) const;
};
struct UniformType
{
std::variant<IdentifierType, StructType> containedType;
StructType containedType;
inline bool operator==(const UniformType& rhs) const;
inline bool operator!=(const UniformType& rhs) const;
@@ -94,7 +134,7 @@ namespace Nz::ShaderAst
inline bool operator!=(const VectorType& rhs) const;
};
using ExpressionType = std::variant<NoType, ArrayType, IdentifierType, PrimitiveType, MatrixType, SamplerType, StructType, UniformType, VectorType>;
using ExpressionType = std::variant<NoType, ArrayType, FunctionType, IdentifierType, IntrinsicFunctionType, PrimitiveType, MatrixType, MethodType, SamplerType, StructType, Type, UniformType, VectorType>;
struct ContainedType
{
@@ -105,25 +145,29 @@ namespace Nz::ShaderAst
{
struct StructMember
{
AttributeValue<BuiltinEntry> builtin;
AttributeValue<bool> cond;
AttributeValue<UInt32> locationIndex;
ExpressionValue<BuiltinEntry> builtin;
ExpressionValue<bool> cond;
ExpressionValue<UInt32> locationIndex;
ExpressionValue<ExpressionType> type;
std::string name;
ExpressionType type;
};
AttributeValue<StructLayout> layout;
ExpressionValue<StructLayout> layout;
std::string name;
std::vector<StructMember> members;
};
inline bool IsArrayType(const ExpressionType& type);
inline bool IsFunctionType(const ExpressionType& type);
inline bool IsIdentifierType(const ExpressionType& type);
inline bool IsIntrinsicFunctionType(const ExpressionType& type);
inline bool IsMatrixType(const ExpressionType& type);
inline bool IsMethodType(const ExpressionType& type);
inline bool IsNoType(const ExpressionType& type);
inline bool IsPrimitiveType(const ExpressionType& type);
inline bool IsSamplerType(const ExpressionType& type);
inline bool IsStructType(const ExpressionType& type);
inline bool IsTypeExpression(const ExpressionType& type);
inline bool IsUniformType(const ExpressionType& type);
inline bool IsVectorType(const ExpressionType& type);
}

View File

@@ -14,6 +14,17 @@ namespace Nz::ShaderAst
}
inline bool FunctionType::operator==(const FunctionType& rhs) const
{
return funcIndex == rhs.funcIndex;
}
inline bool FunctionType::operator!=(const FunctionType& rhs) const
{
return !operator==(rhs);
}
inline bool IdentifierType::operator==(const IdentifierType& rhs) const
{
return name == rhs.name;
@@ -25,6 +36,17 @@ namespace Nz::ShaderAst
}
inline bool IntrinsicFunctionType::operator==(const IntrinsicFunctionType& rhs) const
{
return intrinsic == rhs.intrinsic;
}
inline bool IntrinsicFunctionType::operator!=(const IntrinsicFunctionType& rhs) const
{
return !operator==(rhs);
}
inline bool MatrixType::operator==(const MatrixType& rhs) const
{
return columnCount == rhs.columnCount && rowCount == rhs.rowCount && type == rhs.type;
@@ -36,6 +58,12 @@ namespace Nz::ShaderAst
}
inline bool MethodType::operator!=(const MethodType& rhs) const
{
return !operator==(rhs);
}
inline bool NoType::operator==(const NoType& /*rhs*/) const
{
return true;
@@ -68,6 +96,18 @@ namespace Nz::ShaderAst
return !operator==(rhs);
}
inline bool Type::operator==(const Type& rhs) const
{
return typeIndex == rhs.typeIndex;
}
inline bool Type::operator!=(const Type& rhs) const
{
return !operator==(rhs);
}
inline bool UniformType::operator==(const UniformType& rhs) const
{
return containedType == rhs.containedType;
@@ -90,21 +130,36 @@ namespace Nz::ShaderAst
}
bool IsArrayType(const ExpressionType& type)
inline bool IsArrayType(const ExpressionType& type)
{
return std::holds_alternative<ArrayType>(type);
}
inline bool IsFunctionType(const ExpressionType& type)
{
return std::holds_alternative<FunctionType>(type);
}
inline bool IsIdentifierType(const ExpressionType& type)
{
return std::holds_alternative<IdentifierType>(type);
}
inline bool IsIntrinsicFunctionType(const ExpressionType& type)
{
return std::holds_alternative<IntrinsicFunctionType>(type);
}
inline bool IsMatrixType(const ExpressionType& type)
{
return std::holds_alternative<MatrixType>(type);
}
inline bool IsMethodType(const ExpressionType& type)
{
return std::holds_alternative<MethodType>(type);
}
inline bool IsNoType(const ExpressionType& type)
{
return std::holds_alternative<NoType>(type);
@@ -125,6 +180,11 @@ namespace Nz::ShaderAst
return std::holds_alternative<StructType>(type);
}
bool IsTypeExpression(const ExpressionType& type)
{
return std::holds_alternative<Type>(type);
}
bool IsUniformType(const ExpressionType& type)
{
return std::holds_alternative<UniformType>(type);

View File

@@ -107,7 +107,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::variant<std::string, std::size_t> targetFunction;
ExpressionPtr targetFunction;
std::vector<ExpressionPtr> parameters;
};
@@ -126,7 +126,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ExpressionType targetType;
ExpressionValue<ExpressionType> targetType;
std::array<ExpressionPtr, 4> expressions;
};
@@ -249,10 +249,10 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
ExpressionValue<ExpressionType> type;
std::optional<std::size_t> constIndex;
std::string name;
ExpressionPtr expression;
ExpressionType type;
};
struct NAZARA_SHADER_API DeclareExternalStatement : Statement
@@ -262,13 +262,13 @@ namespace Nz::ShaderAst
struct ExternalVar
{
AttributeValue<UInt32> bindingIndex;
AttributeValue<UInt32> bindingSet;
ExpressionValue<UInt32> bindingIndex;
ExpressionValue<UInt32> bindingSet;
ExpressionValue<ExpressionType> type;
std::string name;
ExpressionType type;
};
AttributeValue<UInt32> bindingSet;
ExpressionValue<UInt32> bindingSet;
std::optional<std::size_t> varIndex;
std::vector<ExternalVar> externalVars;
};
@@ -281,18 +281,18 @@ namespace Nz::ShaderAst
struct Parameter
{
std::string name;
ExpressionType type;
ExpressionValue<ExpressionType> type;
};
AttributeValue<DepthWriteMode> depthWrite;
AttributeValue<bool> earlyFragmentTests;
AttributeValue<ShaderStageType> entryStage;
ExpressionValue<DepthWriteMode> depthWrite;
ExpressionValue<bool> earlyFragmentTests;
ExpressionValue<ShaderStageType> entryStage;
ExpressionValue<ExpressionType> returnType;
std::optional<std::size_t> funcIndex;
std::optional<std::size_t> varIndex;
std::string name;
std::vector<Parameter> parameters;
std::vector<StatementPtr> statements;
ExpressionType returnType;
};
struct NAZARA_SHADER_API DeclareOptionStatement : Statement
@@ -303,7 +303,7 @@ namespace Nz::ShaderAst
std::optional<std::size_t> optIndex;
std::string optName;
ExpressionPtr defaultValue;
ExpressionType optType;
ExpressionValue<ExpressionType> optType;
};
struct NAZARA_SHADER_API DeclareStructStatement : Statement
@@ -323,7 +323,7 @@ namespace Nz::ShaderAst
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr initialExpression;
ExpressionType varType;
ExpressionValue<ExpressionType> varType;
};
struct NAZARA_SHADER_API DiscardStatement : Statement
@@ -345,7 +345,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
AttributeValue<LoopUnroll> unroll;
ExpressionValue<LoopUnroll> unroll;
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr fromExpr;
@@ -359,7 +359,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
AttributeValue<LoopUnroll> unroll;
ExpressionValue<LoopUnroll> unroll;
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr expression;
@@ -388,12 +388,20 @@ namespace Nz::ShaderAst
ExpressionPtr returnExpr;
};
struct NAZARA_SHADER_API ScopedStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
StatementPtr statement;
};
struct NAZARA_SHADER_API WhileStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
AttributeValue<LoopUnroll> unroll;
ExpressionValue<LoopUnroll> unroll;
ExpressionPtr condition;
StatementPtr body;
};

View File

@@ -11,6 +11,7 @@
#include <Nazara/Core/Bitset.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/Ast/AstCloner.hpp>
#include <Nazara/Shader/Ast/AstTypes.hpp>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -19,6 +20,8 @@ namespace Nz::ShaderAst
{
class NAZARA_SHADER_API SanitizeVisitor final : AstCloner
{
friend class AstTypeExpressionVisitor;
public:
struct Options;
@@ -55,6 +58,7 @@ namespace Nz::ShaderAst
struct Identifier;
using AstCloner::CloneExpression;
ExpressionValue<ExpressionType> CloneType(const ExpressionValue<ExpressionType>& exprType) override;
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
ExpressionPtr Clone(AccessIndexExpression& node) override;
@@ -84,50 +88,58 @@ namespace Nz::ShaderAst
StatementPtr Clone(ForStatement& node) override;
StatementPtr Clone(ForEachStatement& node) override;
StatementPtr Clone(MultiStatement& node) override;
StatementPtr Clone(ScopedStatement& node) override;
StatementPtr Clone(WhileStatement& node) override;
const Identifier* FindIdentifier(const std::string_view& identifierName) const;
template<typename F> const Identifier* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
Expression& MandatoryExpr(const ExpressionPtr& node);
Statement& MandatoryStatement(const StatementPtr& node);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right);
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right);
Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const;
void PushScope();
void PopScope();
ExpressionPtr CacheResult(ExpressionPtr expression);
template<typename T> const T& ComputeAttributeValue(AttributeValue<T>& attribute);
ConstantValue ComputeConstantValue(Expression& expr);
template<typename T> std::unique_ptr<T> Optimize(T& node);
std::size_t DeclareFunction(DeclareFunctionStatement& funcDecl);
ConstantValue ComputeConstantValue(Expression& expr) const;
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
template<typename T> std::unique_ptr<T> Optimize(T& node) const;
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
void RegisterBuiltin();
std::size_t RegisterConstant(std::string name, ConstantValue value);
FunctionData& RegisterFunction(std::size_t functionIndex);
std::size_t RegisterFunction(std::string name, FunctionData funcData);
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterStruct(std::string name, StructDescription* description);
std::size_t RegisterType(std::string name, ExpressionType expressionType);
std::size_t RegisterType(std::string name, PartialType partialType);
std::size_t RegisterVariable(std::string name, ExpressionType type);
void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue);
void SanitizeIdentifier(std::string& identifier);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
StatementPtr Unscope(StatementPtr node);
void Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node);
void Validate(BinaryExpression& node);
void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration);
void Validate(CallFunctionExpression& node);
void Validate(CastExpression& node);
void Validate(DeclareVariableStatement& node);
void Validate(IntrinsicExpression& node);
@@ -141,7 +153,6 @@ namespace Nz::ShaderAst
Bitset<> calledByFunctions;
DeclareFunctionStatement* node;
FunctionFlags flags;
bool defined = false;
};
struct Identifier
@@ -153,6 +164,7 @@ namespace Nz::ShaderAst
Function,
Intrinsic,
Struct,
Type,
Variable
};