Shader: First working version on both Vulkan & OpenGL (ES)

This commit is contained in:
Jérôme Leclercq 2021-04-12 15:38:20 +02:00
parent f93a5bbdc1
commit ea99c6a19e
42 changed files with 1803 additions and 1053 deletions

View File

@ -140,6 +140,7 @@ namespace Nz
HLSL,
MSL,
NazaraBinary,
NazaraShader,
SpirV
};

View File

@ -43,7 +43,6 @@
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/ShaderLangLexer.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>

View File

@ -13,9 +13,9 @@
#include <Nazara/Math/Vector4.hpp>
#include <variant>
namespace Nz
namespace Nz::ShaderAst
{
using ShaderConstantValue = std::variant<
using ConstantValue = std::variant<
bool,
float,
Int32,

View File

@ -50,9 +50,17 @@ namespace Nz::ShaderAst
inline bool operator!=(const SamplerType& rhs) const;
};
struct StructType
{
std::size_t structIndex;
inline bool operator==(const StructType& rhs) const;
inline bool operator!=(const StructType& rhs) const;
};
struct UniformType
{
IdentifierType containedType;
std::variant<IdentifierType, StructType> containedType;
inline bool operator==(const UniformType& rhs) const;
inline bool operator!=(const UniformType& rhs) const;
@ -67,7 +75,7 @@ namespace Nz::ShaderAst
inline bool operator!=(const VectorType& rhs) const;
};
using ExpressionType = std::variant<NoType, IdentifierType, PrimitiveType, MatrixType, SamplerType, UniformType, VectorType>;
using ExpressionType = std::variant<NoType, IdentifierType, PrimitiveType, MatrixType, SamplerType, StructType, UniformType, VectorType>;
struct StructDescription
{

View File

@ -51,6 +51,17 @@ namespace Nz::ShaderAst
return !operator==(rhs);
}
inline bool StructType::operator==(const StructType& rhs) const
{
return structIndex == rhs.structIndex;
}
inline bool StructType::operator!=(const StructType& rhs) const
{
return !operator==(rhs);
}
inline bool UniformType::operator==(const UniformType& rhs) const
{
return containedType == rhs.containedType;
@ -61,6 +72,7 @@ namespace Nz::ShaderAst
return !operator==(rhs);
}
inline bool VectorType::operator==(const VectorType& rhs) const
{
return componentCount == rhs.componentCount && type == rhs.type;

View File

@ -0,0 +1,90 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERAST_TRANSFORMVISITOR_HPP
#define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <vector>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API TransformVisitor : AstCloner
{
public:
inline TransformVisitor();
TransformVisitor(const TransformVisitor&) = delete;
TransformVisitor(TransformVisitor&&) = delete;
~TransformVisitor() = default;
StatementPtr Transform(StatementPtr& statement);
TransformVisitor& operator=(const TransformVisitor&) = delete;
TransformVisitor& operator=(TransformVisitor&&) = delete;
private:
struct Identifier;
ExpressionPtr Clone(AccessMemberIdentifierExpression& node) override;
ExpressionPtr Clone(CastExpression& node) override;
ExpressionPtr Clone(IdentifierExpression& node) override;
ExpressionPtr CloneExpression(ExpressionPtr& expr) override;
inline const Identifier* FindIdentifier(const std::string_view& identifierName) const;
void PushScope();
void PopScope();
inline std::size_t RegisterFunction(std::string name);
inline std::size_t RegisterStruct(std::string name, StructDescription description);
inline std::size_t RegisterVariable(std::string name);
ExpressionType ResolveType(const ExpressionType& exprType);
using AstCloner::Visit;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(MultiStatement& node) override;
struct Alias
{
std::variant<ExpressionType> value;
};
struct Struct
{
std::size_t structIndex;
};
struct Variable
{
std::size_t varIndex;
};
struct Identifier
{
std::string name;
std::variant<Alias, Struct, Variable> value;
};
private:
std::size_t m_nextFuncIndex;
std::size_t m_nextVarIndex;
std::vector<Identifier> m_identifiersInScope;
std::vector<StructDescription> m_structs;
std::vector<std::size_t> m_scopeSizes;
};
}
#include <Nazara/Shader/Ast/TransformVisitor.inl>
#endif

View File

@ -0,0 +1,62 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline TransformVisitor::TransformVisitor() :
m_nextFuncIndex(0),
m_nextVarIndex(0)
{
}
inline auto TransformVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
{
auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
if (it == m_identifiersInScope.rend())
return nullptr;
return &*it;
}
inline std::size_t TransformVisitor::RegisterFunction(std::string name)
{
std::size_t funcIndex = m_nextFuncIndex++;
return funcIndex;
}
inline std::size_t TransformVisitor::RegisterStruct(std::string name, StructDescription description)
{
std::size_t structIndex = m_structs.size();
m_structs.emplace_back(std::move(description));
m_identifiersInScope.push_back({
std::move(name),
Struct {
structIndex
}
});
return structIndex;
}
inline std::size_t TransformVisitor::RegisterVariable(std::string name)
{
std::size_t varIndex = m_nextVarIndex++;
m_identifiersInScope.push_back({
std::move(name),
Variable {
varIndex
}
});
return varIndex;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -52,6 +52,7 @@ namespace Nz
void Append(ShaderAst::NoType);
void Append(ShaderAst::PrimitiveType type);
void Append(const ShaderAst::SamplerType& samplerType);
void Append(const ShaderAst::StructType& structType);
void Append(const ShaderAst::UniformType& uniformType);
void Append(const ShaderAst::VectorType& vecType);
template<typename T> void Append(const T& param);
@ -67,7 +68,7 @@ namespace Nz
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);
void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::AccessMemberIdentifierExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;

View File

@ -30,15 +30,25 @@ namespace Nz::ShaderAst
AstCloner& operator=(AstCloner&&) = delete;
protected:
ExpressionPtr CloneExpression(ExpressionPtr& expr);
StatementPtr CloneStatement(StatementPtr& statement);
virtual ExpressionPtr CloneExpression(ExpressionPtr& expr);
virtual StatementPtr CloneStatement(StatementPtr& statement);
virtual StatementPtr Clone(DeclareExternalStatement& node);
virtual StatementPtr Clone(DeclareFunctionStatement& node);
virtual StatementPtr Clone(DeclareStructStatement& node);
virtual StatementPtr Clone(DeclareVariableStatement& node);
virtual ExpressionPtr Clone(AccessMemberIdentifierExpression& node);
virtual ExpressionPtr Clone(AccessMemberIndexExpression& node);
virtual ExpressionPtr Clone(CastExpression& node);
virtual ExpressionPtr Clone(IdentifierExpression& node);
virtual ExpressionPtr Clone(VariableExpression& node);
using AstExpressionVisitor::Visit;
using AstStatementVisitor::Visit;
void Visit(AccessMemberExpression& node) override;
void Visit(AccessMemberIdentifierExpression& node) override;
void Visit(AccessMemberIndexExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
@ -47,6 +57,7 @@ namespace Nz::ShaderAst
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;

View File

@ -26,7 +26,8 @@
#define NAZARA_SHADERAST_STATEMENT_LAST(X) NAZARA_SHADERAST_STATEMENT(X)
#endif
NAZARA_SHADERAST_EXPRESSION(AccessMemberExpression)
NAZARA_SHADERAST_EXPRESSION(AccessMemberIdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(AccessMemberIndexExpression)
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
NAZARA_SHADERAST_EXPRESSION(CastExpression)
@ -35,6 +36,7 @@ NAZARA_SHADERAST_EXPRESSION(ConstantExpression)
NAZARA_SHADERAST_EXPRESSION(IdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_EXPRESSION(VariableExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement)
NAZARA_SHADERAST_STATEMENT(ConditionalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement)

View File

@ -20,7 +20,8 @@ namespace Nz::ShaderAst
AstRecursiveVisitor() = default;
~AstRecursiveVisitor() = default;
void Visit(AccessMemberExpression& node) override;
void Visit(AccessMemberIdentifierExpression& node) override;
void Visit(AccessMemberIndexExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
@ -29,6 +30,7 @@ namespace Nz::ShaderAst
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;

View File

@ -23,7 +23,8 @@ namespace Nz::ShaderAst
AstSerializerBase(AstSerializerBase&&) = delete;
~AstSerializerBase() = default;
void Serialize(AccessMemberExpression& node);
void Serialize(AccessMemberIdentifierExpression& node);
void Serialize(AccessMemberIndexExpression& node);
void Serialize(AssignExpression& node);
void Serialize(BinaryExpression& node);
void Serialize(CastExpression& node);
@ -32,6 +33,7 @@ namespace Nz::ShaderAst
void Serialize(IdentifierExpression& node);
void Serialize(IntrinsicExpression& node);
void Serialize(SwizzleExpression& node);
void Serialize(VariableExpression& node);
void Serialize(BranchStatement& node);
void Serialize(ConditionalStatement& node);

View File

@ -31,7 +31,8 @@ namespace Nz::ShaderAst
private:
using AstExpressionVisitor::Visit;
void Visit(AccessMemberExpression& node) override;
void Visit(AccessMemberIdentifierExpression& node) override;
void Visit(AccessMemberIndexExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
@ -40,6 +41,7 @@ namespace Nz::ShaderAst
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override;
ExpressionCategory m_expressionCategory;
};

View File

@ -34,7 +34,7 @@ namespace Nz::ShaderAst
ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
const ExpressionType& ResolveAlias(const ExpressionType& expressionType);
void Visit(AccessMemberExpression& node) override;
void Visit(AccessMemberIdentifierExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;

View File

@ -15,6 +15,11 @@ namespace Nz::ShaderBuilder
{
namespace Impl
{
struct AccessMember
{
inline std::unique_ptr<ShaderAst::AccessMemberIdentifierExpression> operator()(ShaderAst::ExpressionPtr structExpr, std::vector<std::string> memberIdentifiers) const;
};
struct Assign
{
inline std::unique_ptr<ShaderAst::AssignExpression> operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const;
@ -36,9 +41,19 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const;
};
struct ConditionalExpression
{
inline std::unique_ptr<ShaderAst::ConditionalExpression> operator()(std::string conditionName, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const;
};
struct ConditionalStatement
{
inline std::unique_ptr<ShaderAst::ConditionalStatement> operator()(std::string conditionName, ShaderAst::StatementPtr statement) const;
};
struct Constant
{
inline std::unique_ptr<ShaderAst::ConstantExpression> operator()(ShaderConstantValue value) const;
inline std::unique_ptr<ShaderAst::ConstantExpression> operator()(ShaderAst::ConstantValue value) const;
};
struct DeclareFunction
@ -83,12 +98,20 @@ namespace Nz::ShaderBuilder
{
inline std::unique_ptr<ShaderAst::ReturnStatement> operator()(ShaderAst::ExpressionPtr expr = nullptr) const;
};
struct Swizzle
{
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<ShaderAst::SwizzleComponent> swizzleComponents) const;
};
}
constexpr Impl::AccessMember AccessMember;
constexpr Impl::Assign Assign;
constexpr Impl::Binary Binary;
constexpr Impl::Branch Branch;
constexpr Impl::Cast Cast;
constexpr Impl::ConditionalExpression ConditionalExpression;
constexpr Impl::ConditionalStatement ConditionalStatement;
constexpr Impl::Constant Constant;
constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareStruct DeclareStruct;
@ -99,6 +122,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::Intrinsic Intrinsic;
constexpr Impl::NoParam<ShaderAst::NoOpStatement> NoOp;
constexpr Impl::Return Return;
constexpr Impl::Swizzle Swizzle;
}
#include <Nazara/Shader/ShaderBuilder.inl>

View File

@ -7,6 +7,15 @@
namespace Nz::ShaderBuilder
{
inline std::unique_ptr<ShaderAst::AccessMemberIdentifierExpression> Impl::AccessMember::operator()(ShaderAst::ExpressionPtr structExpr, std::vector<std::string> memberIdentifiers) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessMemberIdentifierExpression>();
accessMemberNode->structExpr = std::move(structExpr);
accessMemberNode->memberIdentifiers = std::move(memberIdentifiers);
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AssignExpression> Impl::Assign::operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const
{
auto assignNode = std::make_unique<ShaderAst::AssignExpression>();
@ -61,7 +70,26 @@ namespace Nz::ShaderBuilder
return castNode;
}
inline std::unique_ptr<ShaderAst::ConstantExpression> Impl::Constant::operator()(ShaderConstantValue value) const
inline std::unique_ptr<ShaderAst::ConditionalExpression> Impl::ConditionalExpression::operator()(std::string conditionName, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const
{
auto condExprNode = std::make_unique<ShaderAst::ConditionalExpression>();
condExprNode->conditionName = std::move(conditionName);
condExprNode->falsePath = std::move(falsePath);
condExprNode->truePath = std::move(truePath);
return condExprNode;
}
inline std::unique_ptr<ShaderAst::ConditionalStatement> Impl::ConditionalStatement::operator()(std::string conditionName, ShaderAst::StatementPtr statement) const
{
auto condStatementNode = std::make_unique<ShaderAst::ConditionalStatement>();
condStatementNode->conditionName = std::move(conditionName);
condStatementNode->statement = std::move(statement);
return condStatementNode;
}
inline std::unique_ptr<ShaderAst::ConstantExpression> Impl::Constant::operator()(ShaderAst::ConstantValue value) const
{
auto constantNode = std::make_unique<ShaderAst::ConstantExpression>();
constantNode->value = std::move(value);
@ -157,6 +185,18 @@ namespace Nz::ShaderBuilder
{
return std::make_unique<T>();
}
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<ShaderAst::SwizzleComponent> swizzleComponents) const
{
auto swizzleNode = std::make_unique<ShaderAst::SwizzleExpression>();
swizzleNode->expression = std::move(expression);
assert(swizzleComponents.size() <= swizzleNode->components.size());
for (std::size_t i = 0; i < swizzleComponents.size(); ++i)
swizzleNode->components[i] = swizzleComponents[i];
return swizzleNode;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -12,7 +12,7 @@
#include <Nazara/Math/Vector3.hpp>
#include <Nazara/Math/Vector4.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/Ast/ConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
@ -60,7 +60,7 @@ namespace Nz::ShaderAst
std::optional<ExpressionType> cachedExpressionType;
};
struct NAZARA_SHADER_API AccessMemberExpression : public Expression
struct NAZARA_SHADER_API AccessMemberIdentifierExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
@ -69,6 +69,15 @@ namespace Nz::ShaderAst
std::vector<std::string> memberIdentifiers;
};
struct NAZARA_SHADER_API AccessMemberIndexExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ExpressionPtr structExpr;
std::vector<std::size_t> memberIndices;
};
struct NAZARA_SHADER_API AssignExpression : public Expression
{
NodeType GetType() const override;
@ -113,7 +122,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ShaderConstantValue value;
ShaderAst::ConstantValue value;
};
struct NAZARA_SHADER_API IdentifierExpression : public Expression
@ -143,6 +152,14 @@ namespace Nz::ShaderAst
ExpressionPtr expression;
};
struct NAZARA_SHADER_API VariableExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t variableId;
};
// Statements
struct Statement;
@ -193,11 +210,12 @@ namespace Nz::ShaderAst
struct ExternalVar
{
std::vector<Attribute> attributes;
std::string name;
std::vector<Attribute> attributes;
ExpressionType type;
};
std::optional<std::size_t> varIndex;
std::vector<Attribute> attributes;
std::vector<ExternalVar> externalVars;
};
@ -213,6 +231,8 @@ namespace Nz::ShaderAst
ExpressionType type;
};
std::optional<std::size_t> funcIndex;
std::optional<std::size_t> varIndex;
std::string name;
std::vector<Attribute> attributes;
std::vector<Parameter> parameters;
@ -225,6 +245,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> structIndex;
std::vector<Attribute> attributes;
StructDescription description;
};
@ -234,6 +255,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr initialExpression;
ExpressionType varType;
@ -274,6 +296,8 @@ namespace Nz::ShaderAst
ExpressionPtr returnExpr;
};
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr);
}
#include <Nazara/Shader/ShaderNodes.inl>

View File

@ -7,6 +7,11 @@
namespace Nz::ShaderAst
{
const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -11,7 +11,9 @@
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <unordered_map>
#include <vector>
namespace Nz
@ -21,17 +23,25 @@ namespace Nz
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept
{
public:
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks);
struct EntryPoint;
struct FuncData;
struct Variable;
inline SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector<FuncData>& funcData);
SpirvAstVisitor(const SpirvAstVisitor&) = delete;
SpirvAstVisitor(SpirvAstVisitor&&) = delete;
~SpirvAstVisitor() = default;
UInt32 AllocateResultId();
UInt32 EvaluateExpression(ShaderAst::ExpressionPtr& expr);
const Variable& GetVariable(std::size_t varIndex) const;
using ExpressionVisitorExcept::Visit;
using StatementVisitorExcept::Visit;
void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::AccessMemberIndexExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::BranchStatement& node) override;
@ -39,27 +49,102 @@ namespace Nz
void Visit(ShaderAst::ConditionalExpression& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareStructStatement& node) override;
void Visit(ShaderAst::DeclareVariableStatement& node) override;
void Visit(ShaderAst::DiscardStatement& node) override;
void Visit(ShaderAst::ExpressionStatement& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;
struct EntryPoint
{
struct Input
{
UInt32 memberIndexConstantId;
UInt32 memberPointerId;
UInt32 varId;
};
struct Output
{
Int32 memberIndex;
UInt32 typeId;
UInt32 varId;
};
struct InputStruct
{
UInt32 pointerId;
UInt32 typeId;
};
ShaderStageType stageType;
std::optional<InputStruct> inputStruct;
std::optional<UInt32> outputStructTypeId;
std::size_t funcIndex;
std::vector<Input> inputs;
std::vector<Output> outputs;
};
struct FuncData
{
std::optional<EntryPoint> entryPointData;
struct Parameter
{
UInt32 pointerTypeId;
UInt32 typeId;
};
struct Variable
{
UInt32 typeId;
UInt32 varId;
};
std::string name;
std::vector<Parameter> parameters;
std::vector<Variable> variables;
std::unordered_map<std::size_t, std::size_t> varIndexToVarId;
UInt32 funcId;
UInt32 funcTypeId;
UInt32 returnTypeId;
};
struct Variable
{
SpirvStorageClass storage;
UInt32 pointerId;
UInt32 pointedTypeId;
};
private:
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) const;
void PushResultId(UInt32 value);
UInt32 PopResultId();
std::vector<SpirvBlock>& m_blocks;
inline void RegisterExternalVariable(std::size_t varIndex, const ShaderAst::ExpressionType& type);
inline void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription structDesc);
inline void RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass);
std::size_t m_extVarIndex;
std::size_t m_funcIndex;
std::vector<std::size_t> m_scopeSizes;
std::vector<FuncData>& m_funcData;
std::vector<ShaderAst::StructDescription> m_structs;
std::vector<std::optional<Variable>> m_variables;
std::vector<SpirvBlock> m_functionBlocks;
std::vector<UInt32> m_resultIds;
SpirvBlock* m_currentBlock;
SpirvSection& m_instructions;
SpirvWriter& m_writer;
};
}

View File

@ -7,17 +7,42 @@
namespace Nz
{
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks) :
m_blocks(blocks),
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector<FuncData>& funcData) :
m_extVarIndex(0),
m_funcIndex(0),
m_funcData(funcData),
m_currentBlock(nullptr),
m_instructions(instructions),
m_writer(writer)
{
m_currentBlock = &m_blocks.back();
}
inline const ShaderAst::ExpressionType& SpirvAstVisitor::GetExpressionType(ShaderAst::Expression& expr) const
void SpirvAstVisitor::RegisterExternalVariable(std::size_t varIndex, const ShaderAst::ExpressionType& type)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
UInt32 pointerId = m_writer.GetExtVarPointerId(varIndex);
SpirvStorageClass storageClass = (IsSamplerType(type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
RegisterVariable(varIndex, m_writer.GetTypeId(type), pointerId, storageClass);
}
inline void SpirvAstVisitor::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription structDesc)
{
if (structIndex >= m_structs.size())
m_structs.resize(structIndex + 1);
m_structs[structIndex] = std::move(structDesc);
}
inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass)
{
if (varIndex >= m_variables.size())
m_variables.resize(varIndex + 1);
m_variables[varIndex] = Variable{
storageClass,
pointerId,
typeId
};
}
}

View File

@ -8,8 +8,8 @@
#define NAZARA_SPIRVCONSTANTCACHE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Ast/ConstantValue.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <memory>
@ -25,6 +25,8 @@ namespace Nz
class NAZARA_SHADER_API SpirvConstantCache
{
public:
using StructCallback = std::function<const ShaderAst::StructDescription&(std::size_t structIndex)>;
SpirvConstantCache(UInt32& resultId);
SpirvConstantCache(const SpirvConstantCache& cache) = delete;
SpirvConstantCache(SpirvConstantCache&& cache) noexcept;
@ -37,8 +39,6 @@ namespace Nz
using ConstantPtr = std::shared_ptr<Constant>;
using TypePtr = std::shared_ptr<Type>;
using IdentifierCallback = std::function<TypePtr(const std::string& identifier)>;
struct Bool {};
struct Float
@ -66,11 +66,6 @@ namespace Nz
UInt32 columnCount;
};
struct Identifier
{
std::string name;
};
struct Image
{
std::optional<SpirvAccessQualifier> qualifier;
@ -112,7 +107,7 @@ namespace Nz
std::vector<Member> members;
};
using AnyType = std::variant<Bool, Float, Function, Identifier, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
using AnyType = std::variant<Bool, Float, Function, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
struct ConstantBool
{
@ -134,10 +129,11 @@ namespace Nz
struct Variable
{
std::optional<std::size_t> funcId; //< For inputs/outputs
std::optional<ConstantPtr> initializer;
std::string debugName;
TypePtr type;
SpirvStorageClass storageClass;
std::optional<ConstantPtr> initializer;
};
using BaseType = std::variant<Bool, Float, Integer, Vector, Matrix, Image>;
@ -166,6 +162,21 @@ namespace Nz
AnyType type;
};
ConstantPtr BuildConstant(const ShaderAst::ConstantValue& value) const;
TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) const;
TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const;
TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
TypePtr BuildType(const ShaderAst::ExpressionType& type) const;
TypePtr BuildType(const ShaderAst::IdentifierType& type) const;
TypePtr BuildType(const ShaderAst::MatrixType& type) const;
TypePtr BuildType(const ShaderAst::NoType& type) const;
TypePtr BuildType(const ShaderAst::PrimitiveType& type) const;
TypePtr BuildType(const ShaderAst::SamplerType& type) const;
TypePtr BuildType(const ShaderAst::StructType& type) const;
TypePtr BuildType(const ShaderAst::StructDescription& structDesc) const;
TypePtr BuildType(const ShaderAst::VectorType& type) const;
TypePtr BuildType(const ShaderAst::UniformType& type) const;
UInt32 GetId(const Constant& c);
UInt32 GetId(const Type& t);
UInt32 GetId(const Variable& v);
@ -174,26 +185,13 @@ namespace Nz
UInt32 Register(Type t);
UInt32 Register(Variable v);
void SetIdentifierCallback(IdentifierCallback callback);
void SetStructCallback(StructCallback callback);
void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete;
SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept;
static ConstantPtr BuildConstant(const ShaderConstantValue& value);
static TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters);
static TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass);
static TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass);
static TypePtr BuildType(const ShaderAst::ExpressionType& type);
static TypePtr BuildType(const ShaderAst::IdentifierType& type);
static TypePtr BuildType(const ShaderAst::MatrixType& type);
static TypePtr BuildType(const ShaderAst::NoType& type);
static TypePtr BuildType(const ShaderAst::PrimitiveType& type);
static TypePtr BuildType(const ShaderAst::SamplerType& type);
static TypePtr BuildType(const ShaderAst::StructDescription& structDesc);
static TypePtr BuildType(const ShaderAst::VectorType& type);
private:
struct DepRegisterer;
struct Eq;
@ -204,7 +202,6 @@ namespace Nz
void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
IdentifierCallback m_identifierCallback;
std::unique_ptr<Internal> m_internal;
};
}

View File

@ -15,13 +15,14 @@
namespace Nz
{
class SpirvAstVisitor;
class SpirvBlock;
class SpirvWriter;
class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAst::ExpressionVisitorExcept
{
public:
inline SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block);
inline SpirvExpressionLoad(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block);
SpirvExpressionLoad(const SpirvExpressionLoad&) = delete;
SpirvExpressionLoad(SpirvExpressionLoad&&) = delete;
~SpirvExpressionLoad() = default;
@ -29,8 +30,8 @@ namespace Nz
UInt32 Evaluate(ShaderAst::Expression& node);
using ExpressionVisitorExcept::Visit;
//void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::AccessMemberIndexExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete;
SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete;
@ -39,7 +40,7 @@ namespace Nz
struct Pointer
{
SpirvStorageClass storage;
UInt32 resultId;
UInt32 pointerId;
UInt32 pointedTypeId;
};
@ -48,6 +49,7 @@ namespace Nz
UInt32 resultId;
};
SpirvAstVisitor& m_visitor;
SpirvBlock& m_block;
SpirvWriter& m_writer;
std::variant<std::monostate, Pointer, Value> m_value;

View File

@ -7,9 +7,10 @@
namespace Nz
{
inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block) :
inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block) :
m_block(block),
m_writer(writer)
m_writer(writer),
m_visitor(visitor)
{
}
}

View File

@ -14,13 +14,14 @@
namespace Nz
{
class SpirvAstVisitor;
class SpirvBlock;
class SpirvWriter;
class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAst::ExpressionVisitorExcept
{
public:
inline SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block);
inline SpirvExpressionStore(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block);
SpirvExpressionStore(const SpirvExpressionStore&) = delete;
SpirvExpressionStore(SpirvExpressionStore&&) = delete;
~SpirvExpressionStore() = default;
@ -28,9 +29,9 @@ namespace Nz
void Store(ShaderAst::ExpressionPtr& node, UInt32 resultId);
using ExpressionVisitorExcept::Visit;
//void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::AccessMemberIndexExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
SpirvExpressionStore& operator=(const SpirvExpressionStore&) = delete;
SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete;
@ -44,9 +45,10 @@ namespace Nz
struct Pointer
{
SpirvStorageClass storage;
UInt32 resultId;
UInt32 pointerId;
};
SpirvAstVisitor& m_visitor;
SpirvBlock& m_block;
SpirvWriter& m_writer;
std::variant<std::monostate, LocalVar, Pointer> m_value;

View File

@ -7,9 +7,10 @@
namespace Nz
{
inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block) :
inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer, SpirvAstVisitor& visitor, SpirvBlock& block) :
m_block(block),
m_writer(writer)
m_writer(writer),
m_visitor(visitor)
{
}
}

View File

@ -9,7 +9,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/Ast/ConstantValue.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <Nazara/Shader/SpirvConstantCache.hpp>
@ -27,7 +27,6 @@ namespace Nz
friend class SpirvBlock;
friend class SpirvExpressionLoad;
friend class SpirvExpressionStore;
friend class SpirvVisitor;
public:
struct Environment;
@ -48,7 +47,6 @@ namespace Nz
};
private:
struct ExtVar;
struct FunctionParameter;
struct OnlyCache {};
@ -56,36 +54,21 @@ namespace Nz
void AppendHeader();
UInt32 GetConstantId(const ShaderConstantValue& value) const;
SpirvConstantCache::TypePtr BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode);
UInt32 GetConstantId(const ShaderAst::ConstantValue& value) const;
UInt32 GetExtVarPointerId(std::size_t varIndex) const;
UInt32 GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode);
const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const;
const ExtVar& GetInputVariable(const std::string& name) const;
const ExtVar& GetOutputVariable(const std::string& name) const;
const ExtVar& GetUniformVariable(const std::string& name) const;
UInt32 GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
UInt32 GetTypeId(const ShaderAst::ExpressionType& type) const;
inline bool IsConditionEnabled(const std::string& condition) const;
UInt32 ReadInputVariable(const std::string& name);
std::optional<UInt32> ReadInputVariable(const std::string& name, OnlyCache);
UInt32 ReadLocalVariable(const std::string& name);
std::optional<UInt32> ReadLocalVariable(const std::string& name, OnlyCache);
UInt32 ReadParameterVariable(const std::string& name);
std::optional<UInt32> ReadParameterVariable(const std::string& name, OnlyCache);
UInt32 ReadUniformVariable(const std::string& name);
std::optional<UInt32> ReadUniformVariable(const std::string& name, OnlyCache);
UInt32 ReadVariable(ExtVar& var);
std::optional<UInt32> ReadVariable(const ExtVar& var, OnlyCache);
UInt32 RegisterConstant(const ShaderConstantValue& value);
UInt32 RegisterConstant(const ShaderAst::ConstantValue& value);
UInt32 RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode);
UInt32 RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass);
UInt32 RegisterType(ShaderAst::ExpressionType type);
void WriteLocalVariable(std::string name, UInt32 resultId);
static SpirvConstantCache::TypePtr BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode);
static void MergeSections(std::vector<UInt32>& output, const SpirvSection& from);
struct Context
@ -93,20 +76,6 @@ namespace Nz
const States* states = nullptr;
};
struct ExtVar
{
UInt32 pointerTypeId;
UInt32 typeId;
UInt32 varId;
std::optional<UInt32> valueId;
};
struct FunctionParameter
{
std::string name;
ShaderAst::ExpressionType type;
};
struct State;
Context m_context;

View File

@ -0,0 +1,225 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
StatementPtr TransformVisitor::Transform(StatementPtr& nodePtr)
{
StatementPtr clone;
PushScope(); //< Global scope
{
clone = AstCloner::Clone(nodePtr);
}
PopScope();
return clone;
}
void TransformVisitor::Visit(BranchStatement& node)
{
for (auto& cond : node.condStatements)
{
PushScope();
{
cond.condition->Visit(*this);
cond.statement->Visit(*this);
}
PopScope();
}
if (node.elseStatement)
{
PushScope();
{
node.elseStatement->Visit(*this);
}
PopScope();
}
}
void TransformVisitor::Visit(ConditionalStatement& node)
{
PushScope();
{
AstCloner::Visit(node);
}
PopScope();
}
ExpressionType TransformVisitor::ResolveType(const ExpressionType& exprType)
{
return std::visit([&](auto&& arg) -> ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, PrimitiveType> ||
std::is_same_v<T, MatrixType> ||
std::is_same_v<T, SamplerType> ||
std::is_same_v<T, StructType> ||
std::is_same_v<T, VectorType>)
{
return exprType;
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
const Identifier* identifier = FindIdentifier(arg.name);
assert(identifier);
assert(std::holds_alternative<Struct>(identifier->value));
return StructType{ std::get<Struct>(identifier->value).structIndex };
}
else if constexpr (std::is_same_v<T, UniformType>)
{
return std::visit([&](auto&& containedArg)
{
ExpressionType resolvedType = ResolveType(containedArg);
assert(std::holds_alternative<StructType>(resolvedType));
return UniformType{ std::get<StructType>(resolvedType) };
}, arg.containedType);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, exprType);
}
void TransformVisitor::Visit(DeclareExternalStatement& node)
{
for (auto& extVar : node.externalVars)
{
extVar.type = ResolveType(extVar.type);
std::size_t varIndex = RegisterVariable(extVar.name);
if (!node.varIndex)
node.varIndex = varIndex;
}
AstCloner::Visit(node);
}
void TransformVisitor::Visit(DeclareFunctionStatement& node)
{
node.funcIndex = m_nextFuncIndex++;
node.returnType = ResolveType(node.returnType);
for (auto& parameter : node.parameters)
parameter.type = ResolveType(parameter.type);
PushScope();
{
for (auto& parameter : node.parameters)
{
std::size_t varIndex = RegisterVariable(parameter.name);
if (!node.varIndex)
node.varIndex = varIndex;
}
AstCloner::Visit(node);
}
PopScope();
}
void TransformVisitor::Visit(DeclareStructStatement& node)
{
node.structIndex = RegisterStruct(node.description.name, node.description);
AstCloner::Visit(node);
}
void TransformVisitor::Visit(DeclareVariableStatement& node)
{
node.varType = ResolveType(node.varType);
node.varIndex = RegisterVariable(node.varName);
AstCloner::Visit(node);
}
void TransformVisitor::Visit(MultiStatement& node)
{
PushScope();
{
AstCloner::Visit(node);
}
PopScope();
}
ExpressionPtr TransformVisitor::Clone(AccessMemberIdentifierExpression& node)
{
auto accessMemberIndex = std::make_unique<AccessMemberIndexExpression>();
accessMemberIndex->structExpr = CloneExpression(node.structExpr);
accessMemberIndex->cachedExpressionType = node.cachedExpressionType;
accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size());
ExpressionType exprType = GetExpressionType(*node.structExpr);
for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i)
{
exprType = ResolveType(exprType);
assert(std::holds_alternative<StructType>(exprType));
std::size_t structIndex = std::get<StructType>(exprType).structIndex;
assert(structIndex < m_structs.size());
const StructDescription& structDesc = m_structs[structIndex];
auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; });
assert(it != structDesc.members.end());
accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it);
exprType = it->type;
}
return accessMemberIndex;
}
ExpressionPtr TransformVisitor::Clone(CastExpression& node)
{
ExpressionPtr expr = AstCloner::Clone(node);
CastExpression* castExpr = static_cast<CastExpression*>(expr.get());
castExpr->targetType = ResolveType(castExpr->targetType);
return expr;
}
ExpressionPtr TransformVisitor::Clone(IdentifierExpression& node)
{
const Identifier* identifier = FindIdentifier(node.identifier);
assert(identifier);
assert(std::holds_alternative<Variable>(identifier->value));
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = node.cachedExpressionType;
varExpr->variableId = std::get<Variable>(identifier->value).varIndex;
return varExpr;
}
ExpressionPtr TransformVisitor::CloneExpression(ExpressionPtr& expr)
{
ExpressionPtr exprPtr = AstCloner::CloneExpression(expr);
if (exprPtr)
{
assert(exprPtr->cachedExpressionType);
*exprPtr->cachedExpressionType = ResolveType(*exprPtr->cachedExpressionType);
}
return exprPtr;
}
void TransformVisitor::PushScope()
{
m_scopeSizes.push_back(m_identifiersInScope.size());
}
void TransformVisitor::PopScope()
{
assert(!m_scopeSizes.empty());
m_identifiersInScope.resize(m_scopeSizes.back());
m_scopeSizes.pop_back();
}
}

View File

@ -330,6 +330,11 @@ namespace Nz
}
}
void GlslWriter::Append(const ShaderAst::StructType& structType)
{
throw std::runtime_error("unexpected struct type");
}
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
{
/* TODO */
@ -371,6 +376,7 @@ namespace Nz
m_currentState->stream << param;
}
template<typename T1, typename T2, typename... Args>
void GlslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params)
{
@ -595,7 +601,7 @@ namespace Nz
Append(")");
}
void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node)
void GlslWriter::Visit(ShaderAst::AccessMemberIdentifierExpression& node)
{
Visit(node.structExpr, true);
@ -741,8 +747,6 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
@ -774,7 +778,7 @@ namespace Nz
EnterScope();
{
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(std::get<ShaderAst::UniformType>(externalVar.type).containedType).name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));

View File

@ -42,13 +42,25 @@ namespace Nz::ShaderAst
return PopStatement();
}
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->externalVars = node.externalVars;
clone->varIndex = node.varIndex;
return clone;
}
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes;
clone->funcIndex = node.funcIndex;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->varIndex = node.varIndex;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
@ -57,15 +69,95 @@ namespace Nz::ShaderAst
return clone;
}
void AstCloner::Visit(AccessMemberExpression& node)
StatementPtr AstCloner::Clone(DeclareStructStatement& node)
{
auto clone = std::make_unique<AccessMemberExpression>();
auto clone = std::make_unique<DeclareStructStatement>();
clone->structIndex = node.structIndex;
clone->description = node.description;
return clone;
}
StatementPtr AstCloner::Clone(DeclareVariableStatement& node)
{
auto clone = std::make_unique<DeclareVariableStatement>();
clone->varIndex = node.varIndex;
clone->varName = node.varName;
clone->varType = node.varType;
clone->initialExpression = CloneExpression(node.initialExpression);
return clone;
}
ExpressionPtr AstCloner::Clone(AccessMemberIdentifierExpression& node)
{
auto clone = std::make_unique<AccessMemberIdentifierExpression>();
clone->memberIdentifiers = node.memberIdentifiers;
clone->structExpr = CloneExpression(node.structExpr);
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
return clone;
}
ExpressionPtr AstCloner::Clone(AccessMemberIndexExpression& node)
{
auto clone = std::make_unique<AccessMemberIndexExpression>();
clone->memberIndices = node.memberIndices;
clone->structExpr = CloneExpression(node.structExpr);
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(CastExpression& node)
{
auto clone = std::make_unique<CastExpression>();
clone->targetType = node.targetType;
std::size_t expressionCount = 0;
for (auto& expr : node.expressions)
{
if (!expr)
break;
clone->expressions[expressionCount++] = CloneExpression(expr);
}
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(IdentifierExpression& node)
{
auto clone = std::make_unique<IdentifierExpression>();
clone->identifier = node.identifier;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(VariableExpression& node)
{
auto clone = std::make_unique<VariableExpression>();
clone->variableId = node.variableId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
void AstCloner::Visit(AccessMemberIdentifierExpression& node)
{
return PushExpression(Clone(node));
}
void AstCloner::Visit(AccessMemberIndexExpression& node)
{
return PushExpression(Clone(node));
}
void AstCloner::Visit(AssignExpression& node)
@ -94,21 +186,7 @@ namespace Nz::ShaderAst
void AstCloner::Visit(CastExpression& node)
{
auto clone = std::make_unique<CastExpression>();
clone->targetType = node.targetType;
std::size_t expressionCount = 0;
for (auto& expr : node.expressions)
{
if (!expr)
break;
clone->expressions[expressionCount++] = CloneExpression(expr);
}
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
PushExpression(Clone(node));
}
void AstCloner::Visit(ConditionalExpression& node)
@ -135,12 +213,7 @@ namespace Nz::ShaderAst
void AstCloner::Visit(IdentifierExpression& node)
{
auto clone = std::make_unique<IdentifierExpression>();
clone->identifier = node.identifier;
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
PushExpression(Clone(node));
}
void AstCloner::Visit(IntrinsicExpression& node)
@ -169,6 +242,11 @@ namespace Nz::ShaderAst
PushExpression(std::move(clone));
}
void AstCloner::Visit(VariableExpression& node)
{
PushExpression(Clone(node));
}
void AstCloner::Visit(BranchStatement& node)
{
auto clone = std::make_unique<BranchStatement>();
@ -197,11 +275,7 @@ namespace Nz::ShaderAst
void AstCloner::Visit(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->externalVars = node.externalVars;
PushStatement(std::move(clone));
PushStatement(Clone(node));
}
void AstCloner::Visit(DeclareFunctionStatement& node)
@ -211,20 +285,12 @@ namespace Nz::ShaderAst
void AstCloner::Visit(DeclareStructStatement& node)
{
auto clone = std::make_unique<DeclareStructStatement>();
clone->description = node.description;
PushStatement(std::move(clone));
PushStatement(Clone(node));
}
void AstCloner::Visit(DeclareVariableStatement& node)
{
auto clone = std::make_unique<DeclareVariableStatement>();
clone->varName = node.varName;
clone->varType = node.varType;
clone->initialExpression = CloneExpression(node.initialExpression);
PushStatement(std::move(clone));
PushStatement(Clone(node));
}
void AstCloner::Visit(DiscardStatement& /*node*/)

View File

@ -7,7 +7,12 @@
namespace Nz::ShaderAst
{
void AstRecursiveVisitor::Visit(AccessMemberExpression& node)
void AstRecursiveVisitor::Visit(AccessMemberIdentifierExpression& node)
{
node.structExpr->Visit(*this);
}
void AstRecursiveVisitor::Visit(AccessMemberIndexExpression& node)
{
node.structExpr->Visit(*this);
}
@ -62,6 +67,11 @@ namespace Nz::ShaderAst
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(VariableExpression& node)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(BranchStatement& node)
{
for (auto& cond : node.condStatements)

View File

@ -53,7 +53,7 @@ namespace Nz::ShaderAst
{
ExpressionType subType = extVar.type;
if (IsUniformType(subType))
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
subType = std::get<IdentifierType>(std::get<UniformType>(subType).containedType);
RegisterVariable(extVar.name, std::move(subType));
}

View File

@ -33,7 +33,7 @@ namespace Nz::ShaderAst
};
}
void AstSerializerBase::Serialize(AccessMemberExpression& node)
void AstSerializerBase::Serialize(AccessMemberIdentifierExpression& node)
{
Node(node.structExpr);
@ -42,6 +42,15 @@ namespace Nz::ShaderAst
Value(identifier);
}
void AstSerializerBase::Serialize(AccessMemberIndexExpression& node)
{
Node(node.structExpr);
Container(node.memberIndices);
for (std::size_t& identifier : node.memberIndices)
SizeT(identifier);
}
void AstSerializerBase::Serialize(AssignExpression& node)
{
Enum(node.op);
@ -133,6 +142,11 @@ namespace Nz::ShaderAst
Enum(node.components[i]);
}
void AstSerializerBase::Serialize(VariableExpression& node)
{
SizeT(node.variableId);
}
void AstSerializerBase::Serialize(BranchStatement& node)
{
@ -364,14 +378,19 @@ namespace Nz::ShaderAst
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
}
else if constexpr (std::is_same_v<T, UniformType>)
else if constexpr (std::is_same_v<T, StructType>)
{
m_stream << UInt8(5);
m_stream << arg.containedType.name;
m_stream << UInt32(arg.structIndex);
}
else if constexpr (std::is_same_v<T, UniformType>)
{
m_stream << UInt8(6);
m_stream << std::get<IdentifierType>(arg.containedType).name;
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(6);
m_stream << UInt8(7);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
}
@ -621,7 +640,18 @@ namespace Nz::ShaderAst
break;
}
case 5: //< UniformType
case 5: //< StructType
{
UInt32 structIndex;
Value(structIndex);
type = StructType{
structIndex
};
break;
}
case 6: //< UniformType
{
std::string containedType;
Value(containedType);
@ -634,7 +664,7 @@ namespace Nz::ShaderAst
break;
}
case 6: //< VectorType
case 7: //< VectorType
{
UInt32 componentCount;
PrimitiveType componentType;

View File

@ -13,7 +13,12 @@ namespace Nz::ShaderAst
return m_expressionCategory;
}
void ShaderAstValueCategory::Visit(AccessMemberExpression& node)
void ShaderAstValueCategory::Visit(AccessMemberIdentifierExpression& node)
{
node.structExpr->Visit(*this);
}
void ShaderAstValueCategory::Visit(AccessMemberIndexExpression& node)
{
node.structExpr->Visit(*this);
}
@ -66,4 +71,9 @@ namespace Nz::ShaderAst
{
node.expression->Visit(*this);
}
void ShaderAstValueCategory::Visit(VariableExpression& node)
{
m_expressionCategory = ExpressionCategory::LValue;
}
}

View File

@ -131,7 +131,7 @@ namespace Nz::ShaderAst
return expressionType;
}
void AstValidator::Visit(AccessMemberExpression& node)
void AstValidator::Visit(AccessMemberIdentifierExpression& node)
{
// Register expressions types
AstScopedVisitor::Visit(node);
@ -351,7 +351,7 @@ namespace Nz::ShaderAst
if (!exprPtr)
break;
ExpressionType exprType = GetExpressionType(*exprPtr);
const ExpressionType& exprType = GetExpressionType(*exprPtr);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "incompatible type" };
@ -552,14 +552,17 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareExternalStatement& node)
{
for (const auto& [attributeType, arg] : node.attributes)
if (!node.attributes.empty())
throw AstError{ "unhandled attribute for external block" };
/*for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
{
default:
throw AstError{ "unhandled attribute for external block" };
}
}
}*/
for (const auto& extVar : node.externalVars)
{

View File

@ -602,7 +602,7 @@ namespace Nz::ShaderLang
if (Peek().type == TokenType::Dot)
{
std::unique_ptr<ShaderAst::AccessMemberExpression> accessMemberNode = std::make_unique<ShaderAst::AccessMemberExpression>();
std::unique_ptr<ShaderAst::AccessMemberIdentifierExpression> accessMemberNode = std::make_unique<ShaderAst::AccessMemberIdentifierExpression>();
accessMemberNode->structExpr = std::move(identifierExpr);
do
@ -685,9 +685,9 @@ namespace Nz::ShaderLang
if (IsVariableInScope(identifier))
{
auto node = ParseIdentifier();
if (node->GetType() == ShaderAst::NodeType::AccessMemberExpression)
if (node->GetType() == ShaderAst::NodeType::AccessMemberIdentifierExpression)
{
ShaderAst::AccessMemberExpression* memberExpr = static_cast<ShaderAst::AccessMemberExpression*>(node.get());
ShaderAst::AccessMemberIdentifierExpression* memberExpr = static_cast<ShaderAst::AccessMemberIdentifierExpression*>(node.get());
if (!memberExpr->memberIdentifiers.empty() && memberExpr->memberIdentifiers.front() == "Sample")
{
if (Peek().type == TokenType::OpenParenthesis)

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
@ -12,6 +13,11 @@
namespace Nz
{
UInt32 SpirvAstVisitor::AllocateResultId()
{
return m_writer.AllocateResultId();
}
UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr)
{
expr->Visit(*this);
@ -20,9 +26,16 @@ namespace Nz
return PopResultId();
}
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node)
auto SpirvAstVisitor::GetVariable(std::size_t varIndex) const -> const Variable&
{
SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock);
assert(varIndex < m_variables.size());
assert(m_variables[varIndex]);
return *m_variables[varIndex];
}
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberIndexExpression& node)
{
SpirvExpressionLoad accessMemberVisitor(m_writer, *this, *m_currentBlock);
PushResultId(accessMemberVisitor.Evaluate(node));
}
@ -30,7 +43,7 @@ namespace Nz
{
UInt32 resultId = EvaluateExpression(node.right);
SpirvExpressionStore storeVisitor(m_writer, *m_currentBlock);
SpirvExpressionStore storeVisitor(m_writer, *this, *m_currentBlock);
storeVisitor.Store(node.left, resultId);
PushResultId(resultId);
@ -38,18 +51,24 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{
ShaderAst::ExpressionType resultExprType = GetExpressionType(node);
assert(IsPrimitiveType(resultExprType));
auto RetrieveBaseType = [](const ShaderAst::ExpressionType& exprType)
{
if (IsPrimitiveType(exprType))
return std::get<ShaderAst::PrimitiveType>(exprType);
else if (IsVectorType(exprType))
return std::get<ShaderAst::VectorType>(exprType).type;
else if (IsMatrixType(exprType))
return std::get<ShaderAst::MatrixType>(exprType).type;
else
throw std::runtime_error("unexpected type");
};
ShaderAst::ExpressionType leftExprType = GetExpressionType(*node.left);
assert(IsPrimitiveType(leftExprType));
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left);
const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right);
ShaderAst::ExpressionType rightExprType = GetExpressionType(*node.right);
assert(IsPrimitiveType(rightExprType));
ShaderAst::PrimitiveType resultType = std::get<ShaderAst::PrimitiveType>(resultExprType);
ShaderAst::PrimitiveType leftType = std::get<ShaderAst::PrimitiveType>(leftExprType);
ShaderAst::PrimitiveType rightType = std::get<ShaderAst::PrimitiveType>(rightExprType);
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
UInt32 leftOperand = EvaluateExpression(node.left);
@ -64,28 +83,16 @@ namespace Nz
{
case ShaderAst::BinaryType::Add:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIAdd;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -94,28 +101,16 @@ namespace Nz
case ShaderAst::BinaryType::Subtract:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpISub;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -124,30 +119,18 @@ namespace Nz
case ShaderAst::BinaryType::Divide:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSDiv;
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUDiv;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -156,31 +139,17 @@ namespace Nz
case ShaderAst::BinaryType::CompEq:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIEqual;
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
// break;
}
break;
@ -188,30 +157,18 @@ namespace Nz
case ShaderAst::BinaryType::CompGe:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdGreaterThan;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSGreaterThan;
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUGreaterThan;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -220,30 +177,18 @@ namespace Nz
case ShaderAst::BinaryType::CompGt:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdGreaterThanEqual;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSGreaterThanEqual;
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUGreaterThanEqual;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -252,30 +197,18 @@ namespace Nz
case ShaderAst::BinaryType::CompLe:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdLessThanEqual;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSLessThanEqual;
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpULessThanEqual;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -284,30 +217,18 @@ namespace Nz
case ShaderAst::BinaryType::CompLt:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdLessThan;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSLessThan;
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpULessThan;
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -316,31 +237,17 @@ namespace Nz
case ShaderAst::BinaryType::CompNe:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Boolean:
return SpirvOp::OpLogicalNotEqual;
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdNotEqual;
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpINotEqual;
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
// break;
}
break;
@ -348,81 +255,51 @@ namespace Nz
case ShaderAst::BinaryType::Multiply:
{
switch (leftType)
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
{
switch (rightType)
if (IsPrimitiveType(leftType))
{
case ShaderAst::PrimitiveType::Float32:
return SpirvOp::OpFMul;
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// swapOperands = true;
// return SpirvOp::OpVectorTimesScalar;
//
// case ShaderAst::PrimitiveType::Mat4x4:
// swapOperands = true;
// return SpirvOp::OpMatrixTimesScalar;
default:
break;
// Handle float * matrix|vector as matrix|vector * float
if (IsMatrixType(rightType))
{
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
}
else if (IsVectorType(rightType))
{
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
}
}
else if (IsPrimitiveType(rightType))
{
if (IsMatrixType(leftType))
return SpirvOp::OpMatrixTimesScalar;
else if (IsVectorType(leftType))
return SpirvOp::OpVectorTimesScalar;
}
else if (IsMatrixType(leftType))
{
if (IsMatrixType(rightType))
return SpirvOp::OpMatrixTimesMatrix;
else if (IsVectorType(rightType))
return SpirvOp::OpMatrixTimesVector;
}
else if (IsMatrixType(rightType))
{
assert(IsVectorType(leftType));
return SpirvOp::OpVectorTimesMatrix;
}
break;
return SpirvOp::OpFMul;
}
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// {
// switch (rightType)
// {
// case ShaderAst::PrimitiveType::Float32:
// return SpirvOp::OpVectorTimesScalar;
//
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// return SpirvOp::OpFMul;
//
// case ShaderAst::PrimitiveType::Mat4x4:
// return SpirvOp::OpVectorTimesMatrix;
//
// default:
// break;
// }
//
// break;
// }
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIMul;
// case ShaderAst::PrimitiveType::Mat4x4:
// {
// switch (rightType)
// {
// case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpMatrixTimesScalar;
// case ShaderAst::PrimitiveType::Float4: return SpirvOp::OpMatrixTimesVector;
// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
//
// default:
// break;
// }
//
// break;
// }
default:
break;
}
@ -454,7 +331,7 @@ namespace Nz
firstCond.statement->Visit(*this);
SpirvBlock mergeBlock(m_writer);
m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
std::optional<std::size_t> nextBlock;
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
@ -463,10 +340,10 @@ namespace Nz
SpirvBlock contentBlock(m_writer);
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
previousConditionId = EvaluateExpression(statement.condition);
m_blocks.emplace_back(std::move(previousContentBlock));
m_functionBlocks.emplace_back(std::move(previousContentBlock));
previousContentBlock = std::move(contentBlock);
m_currentBlock = &previousContentBlock;
@ -479,54 +356,148 @@ namespace Nz
SpirvBlock elseBlock(m_writer);
m_currentBlock = &elseBlock;
node.elseStatement->Visit(*this);
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
m_blocks.emplace_back(std::move(previousContentBlock));
m_blocks.emplace_back(std::move(elseBlock));
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(previousContentBlock));
m_functionBlocks.emplace_back(std::move(elseBlock));
}
else
{
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
m_blocks.emplace_back(std::move(previousContentBlock));
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
m_functionBlocks.emplace_back(std::move(previousContentBlock));
}
m_blocks.emplace_back(std::move(mergeBlock));
m_functionBlocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_blocks.back();
m_currentBlock = &m_functionBlocks.back();
}
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
{
const ShaderAst::ExpressionType& targetExprType = node.targetType;
assert(IsPrimitiveType(targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (auto& exprPtr : node.expressions)
if (IsPrimitiveType(targetExprType))
{
if (!exprPtr)
break;
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
exprResults.push_back(EvaluateExpression(exprPtr));
assert(node.expressions[0] && !node.expressions[1]);
ShaderAst::ExpressionPtr& expression = node.expressions[0];
assert(expression->cachedExpressionType.has_value());
const ShaderAst::ExpressionType& exprType = expression->cachedExpressionType.value();
assert(IsPrimitiveType(exprType));
ShaderAst::PrimitiveType fromType = std::get<ShaderAst::PrimitiveType>(exprType);
UInt32 fromId = EvaluateExpression(expression);
if (targetType == fromType)
return PushResultId(fromId);
std::optional<SpirvOp> castOp;
switch (targetType)
{
case ShaderAst::PrimitiveType::Boolean:
throw std::runtime_error("unsupported cast to boolean");
case ShaderAst::PrimitiveType::Float32:
{
switch (fromType)
{
case ShaderAst::PrimitiveType::Boolean:
throw std::runtime_error("unsupported cast from boolean");
case ShaderAst::PrimitiveType::Float32:
break; //< Already handled
case ShaderAst::PrimitiveType::Int32:
castOp = SpirvOp::OpConvertSToF;
break;
case ShaderAst::PrimitiveType::UInt32:
castOp = SpirvOp::OpConvertUToF;
break;
}
break;
}
case ShaderAst::PrimitiveType::Int32:
{
switch (fromType)
{
case ShaderAst::PrimitiveType::Boolean:
throw std::runtime_error("unsupported cast from boolean");
case ShaderAst::PrimitiveType::Float32:
castOp = SpirvOp::OpConvertFToS;
break;
case ShaderAst::PrimitiveType::Int32:
break; //< Already handled
case ShaderAst::PrimitiveType::UInt32:
castOp = SpirvOp::OpSConvert;
break;
}
break;
}
case ShaderAst::PrimitiveType::UInt32:
{
switch (fromType)
{
case ShaderAst::PrimitiveType::Boolean:
throw std::runtime_error("unsupported cast from boolean");
case ShaderAst::PrimitiveType::Float32:
castOp = SpirvOp::OpConvertFToU;
break;
case ShaderAst::PrimitiveType::Int32:
castOp = SpirvOp::OpUConvert;
break;
case ShaderAst::PrimitiveType::UInt32:
break; //< Already handled
}
break;
}
}
assert(castOp);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(*castOp, m_writer.GetTypeId(targetType), resultId, fromId);
throw std::runtime_error("toudou");
}
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
else
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
assert(IsVectorType(targetExprType));
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (UInt32 exprResultId : exprResults)
appender(exprResultId);
});
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
PushResultId(resultId);
exprResults.push_back(EvaluateExpression(exprPtr));
}
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetExprType));
appender(resultId);
for (UInt32 exprResultId : exprResults)
appender(exprResultId);
});
PushResultId(resultId);
}
}
void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node)
@ -551,10 +522,108 @@ namespace Nz
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
{
assert(node.varIndex);
std::size_t varIndex = *node.varIndex;
for (auto&& extVar : node.externalVars)
RegisterExternalVariable(varIndex++, extVar.type);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareFunctionStatement& node)
{
assert(node.funcIndex);
m_funcIndex = *node.funcIndex;
auto& func = m_funcData[m_funcIndex];
func.funcId = m_writer.AllocateResultId();
m_instructions.Append(SpirvOp::OpFunction, func.returnTypeId, func.funcId, 0, func.funcTypeId);
if (!func.parameters.empty())
{
std::size_t varIndex = *node.varIndex;
for (const auto& param : func.parameters)
{
UInt32 paramResultId = m_writer.AllocateResultId();
m_instructions.Append(SpirvOp::OpFunctionParameter, param.typeId, paramResultId);
RegisterVariable(varIndex++, param.typeId, paramResultId, SpirvStorageClass::Function);
}
}
m_functionBlocks.clear();
m_currentBlock = &m_functionBlocks.emplace_back(m_writer);
CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; });
for (auto& var : func.variables)
{
var.varId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpVariable, var.typeId, var.varId, SpirvStorageClass::Function);
}
if (func.entryPointData)
{
auto& entryPointData = *func.entryPointData;
if (entryPointData.inputStruct)
{
auto& inputStruct = *entryPointData.inputStruct;
std::size_t varIndex = *node.varIndex;
UInt32 paramId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpVariable, inputStruct.pointerId, paramId, SpirvStorageClass::Function);
for (const auto& input : entryPointData.inputs)
{
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpAccessChain, input.memberPointerId, resultId, paramId, input.memberIndexConstantId);
m_currentBlock->Append(SpirvOp::OpCopyMemory, resultId, input.varId);
}
RegisterVariable(varIndex, inputStruct.typeId, paramId, SpirvStorageClass::Function);
}
}
for (auto& statementPtr : node.statements)
statementPtr->Visit(*this);
// Add implicit return
if (!m_functionBlocks.back().IsTerminated())
m_functionBlocks.back().Append(SpirvOp::OpReturn);
for (SpirvBlock& block : m_functionBlocks)
m_instructions.AppendSection(block);
m_instructions.Append(SpirvOp::OpFunctionEnd);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareStructStatement& node)
{
assert(node.structIndex);
RegisterStruct(*node.structIndex, node.description);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node)
{
const auto& func = m_funcData[m_funcIndex];
UInt32 pointerTypeId = m_writer.GetPointerTypeId(node.varType, SpirvStorageClass::Function);
UInt32 typeId = m_writer.GetTypeId(node.varType);
assert(node.varIndex);
auto varIt = func.varIndexToVarId.find(*node.varIndex);
UInt32 varId = func.variables[varIt->second].varId;
RegisterVariable(*node.varIndex, typeId, varId, SpirvStorageClass::Function);
if (node.initialExpression)
m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression));
{
UInt32 value = EvaluateExpression(node.initialExpression);
m_currentBlock->Append(SpirvOp::OpStore, varId, value);
}
}
void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/)
@ -569,19 +638,13 @@ namespace Nz
PopResultId();
}
void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case ShaderAst::IntrinsicType::DotProduct:
{
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0]);
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
@ -598,6 +661,19 @@ namespace Nz
break;
}
case ShaderAst::IntrinsicType::SampleTexture:
{
UInt32 typeId = m_writer.GetTypeId(ShaderAst::VectorType{4, ShaderAst::PrimitiveType::Float32});
UInt32 samplerId = EvaluateExpression(node.parameters[0]);
UInt32 coordinatesId = EvaluateExpression(node.parameters[1]);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
break;
}
case ShaderAst::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
@ -609,23 +685,44 @@ namespace Nz
// nothing to do
}
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
{
if (node.returnExpr)
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
else
m_currentBlock->Append(SpirvOp::OpReturn);
}
void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node)
{
for (auto& statement : node.statements)
statement->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
{
if (node.returnExpr)
{
// Handle entry point return
const auto& func = m_funcData[m_funcIndex];
if (func.entryPointData)
{
auto& entryPointData = *func.entryPointData;
if (entryPointData.outputStructTypeId)
{
UInt32 paramId = EvaluateExpression(node.returnExpr);
for (const auto& output : entryPointData.outputs)
{
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpCompositeExtract, output.typeId, resultId, paramId, output.memberIndex);
m_currentBlock->Append(SpirvOp::OpStore, output.varId, resultId);
}
}
m_currentBlock->Append(SpirvOp::OpReturn);
}
else
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
}
else
m_currentBlock->Append(SpirvOp::OpReturn);
}
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
ShaderAst::ExpressionType targetExprType = GetExpressionType(node);
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
assert(IsPrimitiveType(targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
@ -658,6 +755,12 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::PushResultId(UInt32 value)
{
m_resultIds.push_back(value);

View File

@ -50,11 +50,6 @@ namespace Nz
return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType);
}
bool Compare(const Identifier& lhs, const Identifier& rhs) const
{
return lhs.name == rhs.name;
}
bool Compare(const Image& lhs, const Image& rhs) const
{
return lhs.arrayed == rhs.arrayed
@ -114,6 +109,9 @@ namespace Nz
if (lhs.debugName != rhs.debugName)
return false;
if (lhs.funcId != rhs.funcId)
return false;
if (!Compare(lhs.initializer, rhs.initializer))
return false;
@ -231,11 +229,6 @@ namespace Nz
void Register(const Integer&) {}
void Register(const Void&) {}
void Register(const Identifier& identifier)
{
Register(identifier);
}
void Register(const Image& image)
{
Register(image.sampledType);
@ -406,6 +399,7 @@ namespace Nz
tsl::ordered_map<std::variant<AnyConstant, AnyType>, UInt32 /*id*/, AnyHasher, Eq> ids;
tsl::ordered_map<Variable, UInt32 /*id*/, AnyHasher, Eq> variableIds;
tsl::ordered_map<Structure, FieldOffsets /*fieldOffsets*/, AnyHasher, Eq> structureSizes;
StructCallback structCallback;
UInt32& nextResultId;
};
@ -417,132 +411,8 @@ namespace Nz
SpirvConstantCache::SpirvConstantCache(SpirvConstantCache&& cache) noexcept = default;
SpirvConstantCache::~SpirvConstantCache() = default;
UInt32 SpirvConstantCache::GetId(const Constant& c)
{
auto it = m_internal->ids.find(c.constant);
if (it == m_internal->ids.end())
throw std::runtime_error("constant is not registered");
return it->second;
}
UInt32 SpirvConstantCache::GetId(const Type& t)
{
auto it = m_internal->ids.find(t.type);
if (it == m_internal->ids.end())
throw std::runtime_error("constant is not registered");
return it->second;
}
UInt32 SpirvConstantCache::GetId(const Variable& v)
{
auto it = m_internal->variableIds.find(v);
if (it == m_internal->variableIds.end())
throw std::runtime_error("variable is not registered");
return it->second;
}
UInt32 SpirvConstantCache::Register(Constant c)
{
AnyConstant& constant = c.constant;
DepRegisterer registerer(*this);
registerer.Register(constant);
std::size_t h = m_internal->ids.hash_function()(constant);
auto it = m_internal->ids.find(constant, h);
if (it == m_internal->ids.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->ids.emplace(std::move(constant), resultId).first;
}
return it.value();
}
UInt32 SpirvConstantCache::Register(Type t)
{
AnyType& type = t.type;
if (std::holds_alternative<Identifier>(type))
{
assert(m_identifierCallback);
return Register(*m_identifierCallback(std::get<Identifier>(type).name));
}
DepRegisterer registerer(*this);
registerer.Register(type);
std::size_t h = m_internal->ids.hash_function()(type);
auto it = m_internal->ids.find(type, h);
if (it == m_internal->ids.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->ids.emplace(std::move(type), resultId).first;
}
return it.value();
}
UInt32 SpirvConstantCache::Register(Variable v)
{
DepRegisterer registerer(*this);
registerer.Register(v);
std::size_t h = m_internal->variableIds.hash_function()(v);
auto it = m_internal->variableIds.find(v, h);
if (it == m_internal->variableIds.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->variableIds.emplace(std::move(v), resultId).first;
}
return it.value();
}
void SpirvConstantCache::SetIdentifierCallback(IdentifierCallback callback)
{
m_identifierCallback = std::move(callback);
}
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
{
for (auto&& [object, id] : m_internal->ids)
{
UInt32 resultId = id;
std::visit(overloaded
{
[&](const AnyConstant& constant) { Write(constant, resultId, constants); },
[&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); },
}, object);
}
for (auto&& [variable, id] : m_internal->variableIds)
{
const auto& var = variable;
UInt32 resultId = id;
if (!variable.debugName.empty())
debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName);
constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender)
{
appender(GetId(*var.type));
appender(resultId);
appender(var.storageClass);
if (var.initializer)
appender(GetId((*var.initializer)->constant));
});
}
}
SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default;
auto SpirvConstantCache::BuildConstant(const ShaderConstantValue& value) -> ConstantPtr
auto SpirvConstantCache::BuildConstant(const ShaderAst::ConstantValue& value) const -> ConstantPtr
{
return std::make_shared<Constant>(std::visit([&](auto&& arg) -> SpirvConstantCache::AnyConstant
{
@ -590,7 +460,7 @@ namespace Nz
}, value));
}
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) -> TypePtr
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) const -> TypePtr
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
@ -604,7 +474,7 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr
{
return std::make_shared<Type>(Pointer{
BuildType(type),
@ -612,7 +482,7 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
{
return std::make_shared<Type>(Pointer{
BuildType(type),
@ -620,7 +490,7 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
{
return std::visit([&](auto&& arg) -> TypePtr
{
@ -628,16 +498,13 @@ namespace Nz
}, type);
}
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) const -> TypePtr
{
return std::make_shared<Type>(
Identifier{
type.name
}
);
// No IdentifierType is expected (as they should have been resolved by now)
throw std::runtime_error("unexpected identifier");
}
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) const -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
@ -657,7 +524,7 @@ namespace Nz
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) const -> TypePtr
{
return std::make_shared<Type>(
Matrix{
@ -668,12 +535,12 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) const -> TypePtr
{
return std::make_shared<Type>(Void{});
}
auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) const -> TypePtr
{
//TODO
auto imageType = Image{
@ -690,7 +557,13 @@ namespace Nz
return std::make_shared<Type>(SampledImage{ std::make_shared<Type>(imageType) });
}
auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::StructType& type) const -> TypePtr
{
assert(m_internal->structCallback);
return BuildType(m_internal->structCallback(type.structIndex));
}
auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) const -> TypePtr
{
Structure sType;
sType.name = structDesc.name;
@ -705,11 +578,136 @@ namespace Nz
return std::make_shared<Type>(std::move(sType));
}
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) const -> TypePtr
{
return std::make_shared<Type>(Vector{ BuildType(type.type), UInt32(type.componentCount) });
}
auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr
{
assert(std::holds_alternative<ShaderAst::StructType>(type.containedType));
return BuildType(std::get<ShaderAst::StructType>(type.containedType));
}
UInt32 SpirvConstantCache::GetId(const Constant& c)
{
auto it = m_internal->ids.find(c.constant);
if (it == m_internal->ids.end())
throw std::runtime_error("constant is not registered");
return it->second;
}
UInt32 SpirvConstantCache::GetId(const Type& t)
{
auto it = m_internal->ids.find(t.type);
if (it == m_internal->ids.end())
throw std::runtime_error("type is not registered");
return it->second;
}
UInt32 SpirvConstantCache::GetId(const Variable& v)
{
auto it = m_internal->variableIds.find(v);
if (it == m_internal->variableIds.end())
throw std::runtime_error("variable is not registered");
return it->second;
}
UInt32 SpirvConstantCache::Register(Constant c)
{
AnyConstant& constant = c.constant;
DepRegisterer registerer(*this);
registerer.Register(constant);
std::size_t h = m_internal->ids.hash_function()(constant);
auto it = m_internal->ids.find(constant, h);
if (it == m_internal->ids.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->ids.emplace(std::move(constant), resultId).first;
}
return it.value();
}
UInt32 SpirvConstantCache::Register(Type t)
{
AnyType& type = t.type;
DepRegisterer registerer(*this);
registerer.Register(type);
std::size_t h = m_internal->ids.hash_function()(type);
auto it = m_internal->ids.find(type, h);
if (it == m_internal->ids.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->ids.emplace(std::move(type), resultId).first;
}
return it.value();
}
UInt32 SpirvConstantCache::Register(Variable v)
{
DepRegisterer registerer(*this);
registerer.Register(v);
std::size_t h = m_internal->variableIds.hash_function()(v);
auto it = m_internal->variableIds.find(v, h);
if (it == m_internal->variableIds.end())
{
UInt32 resultId = m_internal->nextResultId++;
it = m_internal->variableIds.emplace(std::move(v), resultId).first;
}
return it.value();
}
void SpirvConstantCache::SetStructCallback(StructCallback callback)
{
m_internal->structCallback = std::move(callback);
}
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
{
for (auto&& [object, id] : m_internal->ids)
{
UInt32 resultId = id;
std::visit(overloaded
{
[&](const AnyConstant& constant) { Write(constant, resultId, constants); },
[&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); },
}, object);
}
for (auto&& [variable, id] : m_internal->variableIds)
{
const auto& var = variable;
UInt32 resultId = id;
if (!variable.debugName.empty())
debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName);
constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender)
{
appender(GetId(*var.type));
appender(resultId);
appender(var.storageClass);
if (var.initializer)
appender(GetId((*var.initializer)->constant));
});
}
}
SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default;
void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants)
{
std::visit([&](auto&& arg)

View File

@ -38,6 +38,8 @@ namespace Nz
while (m_currentCodepoint < m_codepointEnd)
{
const UInt32* instructionBegin = m_currentCodepoint;
UInt32 firstWord = ReadWord();
UInt16 wordCount = static_cast<UInt16>((firstWord >> 16) & 0xFFFF);
@ -50,7 +52,7 @@ namespace Nz
if (!HandleOpcode(*inst, wordCount))
break;
m_currentCodepoint += wordCount - 1;
m_currentCodepoint = instructionBegin + wordCount;
}
}

View File

@ -3,7 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
@ -24,9 +24,8 @@ namespace Nz
{
[this](const Pointer& pointer) -> UInt32
{
UInt32 resultId = m_writer.AllocateResultId();
m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId);
UInt32 resultId = m_visitor.AllocateResultId();
m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.pointerId);
return resultId;
},
@ -41,25 +40,26 @@ namespace Nz
}, m_value);
}
/*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node)
void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberIndexExpression& node)
{
Visit(node.structExpr);
node.structExpr->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 typeId = m_writer.GetTypeId(exprType);
std::visit(overloaded
{
[&](const Pointer& pointer)
{
ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr);
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 typeId = m_writer.GetTypeId(node.exprType);
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.resultId);
appender(pointer.pointerId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
@ -69,9 +69,6 @@ namespace Nz
},
[&](const Value& value)
{
UInt32 resultId = m_writer.AllocateResultId();
UInt32 typeId = m_writer.GetTypeId(node.exprType);
m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
appender(typeId);
@ -89,15 +86,11 @@ namespace Nz
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}*/
}
void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node)
void SpirvExpressionLoad::Visit(ShaderAst::VariableExpression& node)
{
if (node.identifier == "d")
m_value = Value{ m_writer.ReadLocalVariable(node.identifier) };
else
m_value = Value{ m_writer.ReadParameterVariable(node.identifier) };
//Visit(node.var);
const auto& var = m_visitor.GetVariable(node.variableId);
m_value = Pointer{ var.storage, var.pointerId, var.pointedTypeId };
}
}

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
@ -23,11 +24,11 @@ namespace Nz
{
[&](const Pointer& pointer)
{
m_block.Append(SpirvOp::OpStore, pointer.resultId, resultId);
m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId);
},
[&](const LocalVar& value)
{
m_writer.WriteLocalVariable(value.varName, resultId);
throw std::runtime_error("not yet implemented");
},
[](std::monostate)
{
@ -36,49 +37,50 @@ namespace Nz
}, m_value);
}
/*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node)
void SpirvExpressionStore::Visit(ShaderAst::AccessMemberIndexExpression& node)
{
Visit(node.structExpr);
node.structExpr->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
std::visit(overloaded
{
[&](const Pointer& pointer) -> UInt32
[&](const Pointer& pointer)
{
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.resultId);
appender(pointer.pointerId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
});
m_value = Pointer{ pointer.storage, resultId };
return resultId;
m_value = Pointer { pointer.storage, resultId };
},
[](const LocalVar& value) -> UInt32
[&](const LocalVar& value)
{
throw std::runtime_error("not yet implemented");
},
[](std::monostate) -> UInt32
[](std::monostate)
{
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}*/
void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node)
{
m_value = LocalVar{ node.identifier };
}
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
{
throw std::runtime_error("not yet implemented");
}
void SpirvExpressionStore::Visit(ShaderAst::VariableExpression& node)
{
const auto& var = m_visitor.GetVariable(node.variableId);
m_value = Pointer{ var.storage, var.pointerId };
}
}

View File

@ -68,7 +68,7 @@ namespace Nz
UInt32 resultId = 0;
std::size_t currentOperand = 0;
const UInt32* endPtr = startPtr + wordCount;
const UInt32* endPtr = startPtr + wordCount - 1;
while (GetCurrentPtr() < endPtr)
{
const SpirvInstruction::Operand* operand = &instruction.operands[currentOperand];
@ -209,7 +209,7 @@ namespace Nz
m_currentState->stream << "\n";
assert(GetCurrentPtr() == startPtr + wordCount);
assert(GetCurrentPtr() == startPtr + wordCount - 1);
return true;
}

View File

@ -12,10 +12,12 @@
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
#include <tsl/ordered_map.h>
#include <tsl/ordered_set.h>
#include <SpirV/GLSL.std.450.h>
#include <cassert>
#include <map>
#include <stdexcept>
#include <type_traits>
#include <vector>
@ -25,34 +27,61 @@ namespace Nz
{
namespace
{
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct Builtin
{
const char* debugName;
ShaderStageTypeFlags compatibleStages;
SpirvBuiltIn decoration;
};
std::unordered_map<std::string, Builtin> s_builtinMapping = {
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
};
class PreVisitor : public ShaderAst::AstScopedVisitor
{
public:
struct UniformVar
{
std::optional<UInt32> bindingIndex;
UInt32 pointerId;
};
using BuiltinDecoration = std::map<UInt32, SpirvBuiltIn>;
using LocationDecoration = std::map<UInt32, UInt32>;
using ExtInstList = std::unordered_set<std::string>;
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
using StructContainer = std::vector<ShaderAst::StructDescription>;
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
m_conditions(conditions),
m_constantCache(constantCache)
m_constantCache(constantCache),
m_externalBlockIndex(0),
m_funcs(funcs)
{
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
{
const Identifier* identifier = FindIdentifier(identifierName);
if (!identifier)
throw std::runtime_error("invalid identifier " + identifierName);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
assert(structIndex < declaredStructs.size());
return declaredStructs[structIndex];
});
}
void Visit(ShaderAst::AccessMemberExpression& node) override
void Visit(ShaderAst::AccessMemberIndexExpression& node) override
{
/*for (std::size_t index : node.memberIdentifiers)
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
AstRecursiveVisitor::Visit(node);
for (std::size_t index : node.memberIndices)
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(index)));
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::ConditionalExpression& node) override
@ -64,6 +93,8 @@ namespace Nz
Visit(node.truePath);
else
Visit(node.falsePath);*/
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::ConditionalStatement& node) override
@ -79,52 +110,189 @@ namespace Nz
{
std::visit([&](auto&& arg)
{
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
m_constantCache.Register(*m_constantCache.BuildConstant(arg));
}, node.value);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareExternalStatement& node) override
{
assert(node.varIndex);
std::size_t varIndex = *node.varIndex;
for (auto& extVar : node.externalVars)
{
SpirvConstantCache::Variable variable;
variable.debugName = extVar.name;
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable);
for (const auto& [attributeType, attributeParam] : extVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
uniformVar.bindingIndex = std::get<long long>(attributeParam);
break;
}
}
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
funcs.emplace_back(node);
std::optional<ShaderStageType> entryPointType;
for (auto& attribute : node.attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
entryPointType = it->second;
break;
}
}
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
if (funcIndex >= m_funcs.size())
m_funcs.resize(funcIndex + 1);
auto& funcData = m_funcs[funcIndex];
funcData.name = node.name;
if (!entryPointType)
{
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes));
for (auto& parameter : node.parameters)
{
auto& funcParam = funcData.parameters.emplace_back();
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type));
}
}
else
{
using EntryPoint = SpirvAstVisitor::EntryPoint;
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {}));
std::optional<EntryPoint::InputStruct> inputStruct;
std::vector<EntryPoint::Input> inputs;
if (!node.parameters.empty())
{
assert(node.parameters.size() == 1);
auto& parameter = node.parameters.front();
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc.members)
{
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
{
inputs.push_back({
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)),
varId
});
}
memberIndex++;
}
inputStruct = EntryPoint::InputStruct{
m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type))
};
}
std::optional<UInt32> outputStructId;
std::vector<EntryPoint::Output> outputs;
if (!IsNoType(node.returnType))
{
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc.members)
{
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
{
outputs.push_back({
Int32(memberIndex),
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
varId
});
}
memberIndex++;
}
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
}
funcData.entryPointData = EntryPoint{
*entryPointType,
inputStruct,
outputStructId,
funcIndex,
std::move(inputs),
std::move(outputs)
};
}
m_funcIndex = funcIndex;
AstScopedVisitor::Visit(node);
m_funcIndex.reset();
}
void Visit(ShaderAst::DeclareStructStatement& node) override
{
AstScopedVisitor::Visit(node);
SpirvConstantCache::Structure sType;
sType.name = node.description.name;
assert(node.structIndex);
std::size_t structIndex = *node.structIndex;
if (structIndex >= declaredStructs.size())
declaredStructs.resize(structIndex + 1);
for (const auto& [name, attribute, type] : node.description.members)
{
auto& sMembers = sType.members.emplace_back();
sMembers.name = name;
sMembers.type = SpirvConstantCache::BuildType(type);
}
declaredStructs[structIndex] = node.description;
m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) });
m_constantCache.Register(*m_constantCache.BuildType(node.description));
}
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
AstScopedVisitor::Visit(node);
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
assert(node.varIndex);
func.varIndexToVarId[*node.varIndex] = func.variables.size();
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function));
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
AstScopedVisitor::Visit(node);
}
@ -144,40 +312,88 @@ namespace Nz
case ShaderAst::IntrinsicType::DotProduct:
break;
}
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::SwizzleExpression& node) override
{
AstScopedVisitor::Visit(node);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{
std::optional<std::reference_wrapper<Builtin>> builtinOpt;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
if (it != s_builtinMapping.end())
{
builtinOpt = it->second;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
if (builtinOpt)
{
Builtin& builtin = *builtinOpt;
if ((builtin.compatibleStages & entryPointType) == 0)
return 0;
SpirvBuiltIn builtinDecoration = builtin.decoration;
SpirvConstantCache::Variable variable;
variable.debugName = builtin.debugName;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
builtinDecorations[varId] = builtinDecoration;
return varId;
}
else if (attributeLocation)
{
SpirvConstantCache::Variable variable;
variable.debugName = member.name;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = *attributeLocation;
return varId;
}
return 0;
}
BuiltinDecoration builtinDecorations;
ExtInstList extInsts;
FunctionContainer funcs;
ExtVarContainer extVars;
LocationDecoration locationDecorations;
StructContainer declaredStructs;
private:
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
std::optional<std::size_t> m_funcIndex;
std::size_t m_externalBlockIndex;
std::vector<SpirvAstVisitor::FuncData>& m_funcs;
};
template<typename T>
constexpr ShaderAst::PrimitiveType GetBasicType()
{
if constexpr (std::is_same_v<T, bool>)
return ShaderAst::PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return(ShaderAst::PrimitiveType::Float32);
else if constexpr (std::is_same_v<T, Int32>)
return(ShaderAst::PrimitiveType::Int32);
else if constexpr (std::is_same_v<T, Vector2f>)
return(ShaderAst::PrimitiveType::Float2);
else if constexpr (std::is_same_v<T, Vector3f>)
return(ShaderAst::PrimitiveType::Float3);
else if constexpr (std::is_same_v<T, Vector4f>)
return(ShaderAst::PrimitiveType::Float4);
else if constexpr (std::is_same_v<T, Vector2i32>)
return(ShaderAst::PrimitiveType::Int2);
else if constexpr (std::is_same_v<T, Vector3i32>)
return(ShaderAst::PrimitiveType::Int3);
else if constexpr (std::is_same_v<T, Vector4i32>)
return(ShaderAst::PrimitiveType::Int4);
else
static_assert(AlwaysFalse<T>::value, "unhandled type");
}
}
struct SpirvWriter::State
@ -194,18 +410,13 @@ namespace Nz
UInt32 id;
};
tsl::ordered_map<std::string, ExtVar> inputIds;
tsl::ordered_map<std::string, ExtVar> outputIds;
tsl::ordered_map<std::string, ExtVar> parameterIds;
tsl::ordered_map<std::string, ExtVar> uniformIds;
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<Func> funcs;
std::vector<SpirvBlock> functionBlocks;
std::vector<SpirvAstVisitor::FuncData> funcs;
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
PreVisitor* preVisitor;
// Output
SpirvSection header;
@ -226,6 +437,9 @@ namespace Nz
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
ShaderAst::TransformVisitor transformVisitor;
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
m_context.states = &conditions;
State state;
@ -235,245 +449,37 @@ namespace Nz
m_currentState = nullptr;
});
ShaderAst::AstCloner cloner;
// Register all extended instruction sets
PreVisitor preVisitor(conditions, state.constantTypeCache);
shader->Visit(preVisitor);
PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs);
transformedShader->Visit(preVisitor);
m_currentState->preVisitor = &preVisitor;
for (const std::string& extInst : preVisitor.extInsts)
state.extensionInstructions[extInst] = AllocateResultId();
// Register all types
/*for (const auto& func : shader.GetFunctions())
{
RegisterType(func.returnType);
for (const auto& param : func.parameters)
RegisterType(param.type);
}
for (const auto& input : shader.GetInputs())
RegisterPointerType(input.type, SpirvStorageClass::Input);
for (const auto& output : shader.GetOutputs())
RegisterPointerType(output.type, SpirvStorageClass::Output);
for (const auto& uniform : shader.GetUniforms())
RegisterPointerType(uniform.type, (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform);
for (const auto& func : shader.GetFunctions())
RegisterFunctionType(func.returnType, func.parameters);
for (const auto& type : preVisitor.variableTypes)
RegisterType(type);
for (const auto& builtin : preVisitor.builtinVars)
RegisterType(builtin->type);
// Register result id and debug infos for global variables/functions
for (const auto& builtin : preVisitor.builtinVars)
{
SpirvConstantCache::Variable variable;
SpirvBuiltIn builtinDecoration;
switch (builtin->entry)
{
case ShaderAst::BuiltinEntry::VertexPosition:
variable.debugName = "builtin_VertexPosition";
variable.storageClass = SpirvStorageClass::Output;
builtinDecoration = SpirvBuiltIn::Position;
break;
default:
throw std::runtime_error("unexpected builtin type");
}
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
assert(IsBasicType(builtinExprType));
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar builtinData;
builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass);
builtinData.typeId = GetTypeId(builtinType);
builtinData.varId = varId;
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);
state.builtinIds.emplace(builtin->entry, builtinData);
}
for (const auto& input : shader.GetInputs())
{
SpirvConstantCache::Variable variable;
variable.debugName = input.name;
variable.storageClass = SpirvStorageClass::Input;
variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar inputData;
inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass);
inputData.typeId = GetTypeId(input.type);
inputData.varId = varId;
state.inputIds.emplace(input.name, std::move(inputData));
if (input.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
}
for (const auto& output : shader.GetOutputs())
{
SpirvConstantCache::Variable variable;
variable.debugName = output.name;
variable.storageClass = SpirvStorageClass::Output;
variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar outputData;
outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass);
outputData.typeId = GetTypeId(output.type);
outputData.varId = varId;
state.outputIds.emplace(output.name, std::move(outputData));
if (output.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
}
for (const auto& uniform : shader.GetUniforms())
{
SpirvConstantCache::Variable variable;
variable.debugName = uniform.name;
variable.storageClass = (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass);
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
ExtVar uniformData;
uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass);
uniformData.typeId = GetTypeId(uniform.type);
uniformData.varId = varId;
state.uniformIds.emplace(uniform.name, std::move(uniformData));
if (uniform.bindingIndex)
{
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
}
}*/
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
{
auto& funcData = state.funcs.emplace_back();
funcData.statement = &func;
funcData.id = AllocateResultId();
funcData.typeId = GetFunctionTypeId(func);
state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
}
std::size_t funcIndex = 0;
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
{
auto& funcData = state.funcs[funcIndex++];
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
state.functionBlocks.clear();
state.functionBlocks.emplace_back(*this);
state.parameterIds.clear();
for (const auto& param : func.parameters)
{
UInt32 paramResultId = AllocateResultId();
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
ExtVar parameterData;
parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
parameterData.typeId = GetTypeId(param.type);
parameterData.varId = paramResultId;
state.parameterIds.emplace(param.name, std::move(parameterData));
}
SpirvAstVisitor visitor(*this, state.functionBlocks);
for (const auto& statement : func.statements)
statement->Visit(visitor);
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} });
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
for (SpirvBlock& block : state.functionBlocks)
state.instructions.AppendSection(block);
state.instructions.Append(SpirvOp::OpFunctionEnd);
}
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
transformedShader->Visit(visitor);
AppendHeader();
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
for (auto&& [varIndex, extVar] : preVisitor.extVars)
{
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
if (!statement)
continue;
auto it = std::find_if(state.funcs.begin(), state.funcs.end(), [&](const auto& funcData) { return funcData.statement == statement; });
assert(it != state.funcs.end());
const auto& entryFunc = *it;
SpirvExecutionModel execModel;
ShaderStageType stage = static_cast<ShaderStageType>(i);
switch (stage)
if (extVar.bindingIndex)
{
case ShaderStageType::Fragment:
execModel = SpirvExecutionModel::Fragment;
break;
case ShaderStageType::Vertex:
execModel = SpirvExecutionModel::Vertex;
break;
default:
throw std::runtime_error("not yet implemented");
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, 0);
}
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(entryFunc.id);
appender(statement->name);
for (const auto& [name, varData] : state.builtinIds)
appender(varData.varId);
for (const auto& [name, varData] : state.inputIds)
appender(varData.varId);
for (const auto& [name, varData] : state.outputIds)
appender(varData.varId);
});
if (stage == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
}
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
for (auto&& [varId, location] : preVisitor.locationDecorations)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
std::vector<UInt32> ret;
MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo);
@ -511,171 +517,53 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
}
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
}
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
{
auto it = m_currentState->builtinIds.find(builtin);
assert(it != m_currentState->builtinIds.end());
return it->second;
}
auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return it->second;
}
auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->outputIds.find(name);
assert(it != m_currentState->outputIds.end());
return it->second;
}
auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return it.value();
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
}
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadLocalVariable(const std::string& name)
{
auto it = m_currentState->varToResult.find(name);
assert(it != m_currentState->varToResult.end());
return it->second;
}
std::optional<UInt32> SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->varToResult.find(name);
if (it == m_currentState->varToResult.end())
return {};
return it->second;
}
UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
{
auto it = m_currentState->parameterIds.find(name);
assert(it != m_currentState->parameterIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->parameterIds.find(name);
assert(it != m_currentState->parameterIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadVariable(ExtVar& var)
{
if (!var.valueId.has_value())
std::optional<UInt32> fragmentFuncId;
for (auto& func : m_currentState->funcs)
{
UInt32 resultId = AllocateResultId();
m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
var.valueId = resultId;
if (func.entryPointData)
{
auto& entryPointData = func.entryPointData.value();
SpirvExecutionModel execModel;
switch (entryPointData.stageType)
{
case ShaderStageType::Fragment:
execModel = SpirvExecutionModel::Fragment;
break;
case ShaderStageType::Vertex:
execModel = SpirvExecutionModel::Vertex;
break;
default:
throw std::runtime_error("not yet implemented");
}
m_currentState->header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(func.funcId);
appender(func.name);
for (const auto& input : entryPointData.inputs)
appender(input.varId);
for (const auto& output : entryPointData.outputs)
appender(output.varId);
});
if (entryPointData.stageType == ShaderStageType::Fragment)
fragmentFuncId = func.funcId;
}
}
return var.valueId.value();
}
if (fragmentFuncId)
m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft);
std::optional<UInt32> SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache)
{
if (!var.valueId.has_value())
return {};
return var.valueId.value();
}
UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
}
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
{
assert(m_currentState);
m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
}
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
@ -686,7 +574,56 @@ namespace Nz
for (const auto& parameter : functionNode.parameters)
parameterTypes.push_back(parameter.type);
return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes);
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes);
}
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildConstant(value));
}
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
{
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
assert(it != m_currentState->preVisitor->extVars.end());
return it->second.pointerId;
}
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
}
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildType(type));
}
void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from)