Shader: Rework scope handling
This commit is contained in:
parent
feffcfa6e5
commit
f93a5bbdc1
|
|
@ -32,12 +32,9 @@
|
|||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/GlslWriter.hpp>
|
||||
#include <Nazara/Shader/Shader.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCache.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCloner.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
|
||||
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstSerializer.hpp>
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderWriter.hpp>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
namespace Nz
|
||||
{
|
||||
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstRecursiveVisitor
|
||||
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstScopedVisitor
|
||||
{
|
||||
public:
|
||||
struct Environment;
|
||||
|
|
@ -57,8 +57,8 @@ namespace Nz
|
|||
template<typename T> void Append(const T& param);
|
||||
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
|
||||
void AppendCommentSection(const std::string& section);
|
||||
void AppendEntryPoint(ShaderStageType shaderStage);
|
||||
void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
|
||||
void AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader);
|
||||
void AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
|
||||
void AppendLine(const std::string& txt = {});
|
||||
template<typename... Args> void AppendLine(Args&&... params);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,56 +0,0 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef NAZARA_SHADERASTCACHE_HPP
|
||||
#define NAZARA_SHADERASTCACHE_HPP
|
||||
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Utility/Enums.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
struct AstCache
|
||||
{
|
||||
struct Identifier;
|
||||
|
||||
struct Alias
|
||||
{
|
||||
std::variant<ExpressionType> value;
|
||||
};
|
||||
|
||||
struct Variable
|
||||
{
|
||||
ExpressionType type;
|
||||
};
|
||||
|
||||
struct Identifier
|
||||
{
|
||||
std::string name;
|
||||
std::variant<Alias, Variable, StructDescription> value;
|
||||
};
|
||||
|
||||
struct Scope
|
||||
{
|
||||
std::optional<std::size_t> parentScopeIndex;
|
||||
std::vector<Identifier> identifiers;
|
||||
};
|
||||
|
||||
inline void Clear();
|
||||
inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const;
|
||||
inline std::size_t GetScopeId(const Node* node) const;
|
||||
|
||||
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
|
||||
std::unordered_map<const Expression*, ExpressionType> nodeExpressionType;
|
||||
std::unordered_map<const Node*, std::size_t> scopeIdByNode;
|
||||
std::vector<Scope> scopes;
|
||||
};
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/ShaderAstCache.inl>
|
||||
|
||||
#endif
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstCache.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
inline void AstCache::Clear()
|
||||
{
|
||||
entryFunctions.fill(nullptr);
|
||||
nodeExpressionType.clear();
|
||||
scopeIdByNode.clear();
|
||||
scopes.clear();
|
||||
}
|
||||
|
||||
inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier*
|
||||
{
|
||||
assert(startingScopeId < scopes.size());
|
||||
|
||||
std::optional<std::size_t> scopeId = startingScopeId;
|
||||
do
|
||||
{
|
||||
const auto& scope = scopes[*scopeId];
|
||||
auto it = std::find_if(scope.identifiers.rbegin(), scope.identifiers.rend(), [&](const auto& identifier) { return identifier.name == identifierName; });
|
||||
if (it != scope.identifiers.rend())
|
||||
return &*it;
|
||||
|
||||
scopeId = scope.parentScopeIndex;
|
||||
} while (scopeId);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline std::size_t AstCache::GetScopeId(const Node* node) const
|
||||
{
|
||||
auto it = scopeIdByNode.find(node);
|
||||
assert(it != scopeIdByNode.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
|
@ -33,7 +33,7 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr CloneExpression(ExpressionPtr& expr);
|
||||
StatementPtr CloneStatement(StatementPtr& statement);
|
||||
|
||||
virtual std::unique_ptr<DeclareFunctionStatement> Clone(DeclareFunctionStatement& node);
|
||||
virtual StatementPtr Clone(DeclareFunctionStatement& node);
|
||||
|
||||
using AstExpressionVisitor::Visit;
|
||||
using AstStatementVisitor::Visit;
|
||||
|
|
|
|||
|
|
@ -1,58 +0,0 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef NAZARA_SHADERASTEXPRESSIONTYPE_HPP
|
||||
#define NAZARA_SHADERASTEXPRESSIONTYPE_HPP
|
||||
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/ExpressionType.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
struct AstCache;
|
||||
|
||||
class NAZARA_SHADER_API ExpressionTypeVisitor : public AstExpressionVisitor
|
||||
{
|
||||
public:
|
||||
ExpressionTypeVisitor() = default;
|
||||
ExpressionTypeVisitor(const ExpressionTypeVisitor&) = delete;
|
||||
ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete;
|
||||
~ExpressionTypeVisitor() = default;
|
||||
|
||||
ExpressionType GetExpressionType(Expression& expression, AstCache* cache);
|
||||
|
||||
ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete;
|
||||
ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete;
|
||||
|
||||
private:
|
||||
ExpressionType GetExpressionTypeInternal(Expression& expression);
|
||||
ExpressionType ResolveAlias(Expression& expression, ExpressionType expressionType);
|
||||
|
||||
void Visit(Expression& expression);
|
||||
|
||||
void Visit(AccessMemberExpression& node) override;
|
||||
void Visit(AssignExpression& node) override;
|
||||
void Visit(BinaryExpression& node) override;
|
||||
void Visit(CastExpression& node) override;
|
||||
void Visit(ConditionalExpression& node) override;
|
||||
void Visit(ConstantExpression& node) override;
|
||||
void Visit(IdentifierExpression& node) override;
|
||||
void Visit(IntrinsicExpression& node) override;
|
||||
void Visit(SwizzleExpression& node) override;
|
||||
|
||||
AstCache* m_cache;
|
||||
std::optional<ExpressionType> m_lastExpressionType;
|
||||
};
|
||||
|
||||
inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr);
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.inl>
|
||||
|
||||
#endif
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache)
|
||||
{
|
||||
ExpressionTypeVisitor visitor;
|
||||
return visitor.GetExpressionType(expression, cache);
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef NAZARA_SHADER_SCOPED_VISITOR_HPP
|
||||
#define NAZARA_SHADER_SCOPED_VISITOR_HPP
|
||||
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
class NAZARA_SHADER_API AstScopedVisitor : public AstRecursiveVisitor
|
||||
{
|
||||
public:
|
||||
struct Identifier;
|
||||
|
||||
AstScopedVisitor() = default;
|
||||
~AstScopedVisitor() = default;
|
||||
|
||||
inline const Identifier* FindIdentifier(const std::string_view& identifierName) const;
|
||||
|
||||
void ScopedVisit(StatementPtr& nodePtr);
|
||||
|
||||
using AstRecursiveVisitor::Visit;
|
||||
void Visit(BranchStatement& node) override;
|
||||
void Visit(ConditionalStatement& node) override;
|
||||
void Visit(DeclareExternalStatement& node) override;
|
||||
void Visit(DeclareFunctionStatement& node) override;
|
||||
void Visit(DeclareStructStatement& node) override;
|
||||
void Visit(DeclareVariableStatement& node) override;
|
||||
void Visit(MultiStatement& node) override;
|
||||
|
||||
struct Alias
|
||||
{
|
||||
std::variant<ExpressionType> value;
|
||||
};
|
||||
|
||||
struct Variable
|
||||
{
|
||||
ExpressionType type;
|
||||
};
|
||||
|
||||
struct Identifier
|
||||
{
|
||||
std::string name;
|
||||
std::variant<Alias, Variable, StructDescription> value;
|
||||
};
|
||||
|
||||
protected:
|
||||
void PushScope();
|
||||
void PopScope();
|
||||
|
||||
inline void RegisterStruct(StructDescription structDesc);
|
||||
inline void RegisterVariable(std::string name, ExpressionType type);
|
||||
|
||||
private:
|
||||
std::vector<Identifier> m_identifiersInScope;
|
||||
std::vector<std::size_t> m_scopeSizes;
|
||||
};
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/ShaderAstScopedVisitor.inl>
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
inline auto AstScopedVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
|
||||
{
|
||||
auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
|
||||
if (it == m_identifiersInScope.rend())
|
||||
return nullptr;
|
||||
|
||||
return &*it;
|
||||
}
|
||||
|
||||
inline void AstScopedVisitor::RegisterStruct(StructDescription structDesc)
|
||||
{
|
||||
std::string name = structDesc.name;
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
std::move(structDesc)
|
||||
});
|
||||
}
|
||||
|
||||
inline void AstScopedVisitor::RegisterVariable(std::string name, ExpressionType type)
|
||||
{
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
Variable { std::move(type) }
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
|
@ -9,13 +9,12 @@
|
|||
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCache.hpp>
|
||||
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
|
||||
#include <Nazara/Utility/Enums.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
class NAZARA_SHADER_API AstValidator : public AstRecursiveVisitor
|
||||
class NAZARA_SHADER_API AstValidator final : public AstScopedVisitor
|
||||
{
|
||||
public:
|
||||
inline AstValidator();
|
||||
|
|
@ -23,28 +22,24 @@ namespace Nz::ShaderAst
|
|||
AstValidator(AstValidator&&) = delete;
|
||||
~AstValidator() = default;
|
||||
|
||||
bool Validate(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr);
|
||||
bool Validate(StatementPtr& node, std::string* error = nullptr);
|
||||
|
||||
private:
|
||||
const ExpressionType& GetExpressionType(Expression& expression);
|
||||
Expression& MandatoryExpr(ExpressionPtr& node);
|
||||
Statement& MandatoryStatement(StatementPtr& node);
|
||||
void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right);
|
||||
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right);
|
||||
|
||||
ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
|
||||
|
||||
AstCache::Scope& EnterScope();
|
||||
void ExitScope();
|
||||
|
||||
void RegisterExpressionType(Expression& node, ExpressionType expressionType);
|
||||
void RegisterScope(Node& node);
|
||||
const ExpressionType& ResolveAlias(const ExpressionType& expressionType);
|
||||
|
||||
void Visit(AccessMemberExpression& node) override;
|
||||
void Visit(AssignExpression& node) override;
|
||||
void Visit(BinaryExpression& node) override;
|
||||
void Visit(CastExpression& node) override;
|
||||
void Visit(ConditionalExpression& node) override;
|
||||
void Visit(ConstantExpression& node) override;
|
||||
void Visit(ConditionalExpression& node) override;
|
||||
void Visit(IdentifierExpression& node) override;
|
||||
void Visit(IntrinsicExpression& node) override;
|
||||
void Visit(SwizzleExpression& node) override;
|
||||
|
|
@ -54,17 +49,15 @@ namespace Nz::ShaderAst
|
|||
void Visit(DeclareExternalStatement& node) override;
|
||||
void Visit(DeclareFunctionStatement& node) override;
|
||||
void Visit(DeclareStructStatement& node) override;
|
||||
void Visit(DeclareVariableStatement& node) override;
|
||||
void Visit(ExpressionStatement& node) override;
|
||||
void Visit(MultiStatement& node) override;
|
||||
void Visit(ReturnStatement& node) override;
|
||||
|
||||
struct Context;
|
||||
|
||||
Context* m_context;
|
||||
};
|
||||
|
||||
NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr);
|
||||
NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr);
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/ShaderAstValidator.inl>
|
||||
|
|
|
|||
|
|
@ -56,6 +56,8 @@ namespace Nz::ShaderAst
|
|||
|
||||
Expression& operator=(const Expression&) = delete;
|
||||
Expression& operator=(Expression&&) noexcept = default;
|
||||
|
||||
std::optional<ExpressionType> cachedExpressionType;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API AccessMemberExpression : public Expression
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ namespace Nz
|
|||
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept
|
||||
{
|
||||
public:
|
||||
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks, ShaderAst::AstCache* cache);
|
||||
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks);
|
||||
SpirvAstVisitor(const SpirvAstVisitor&) = delete;
|
||||
SpirvAstVisitor(SpirvAstVisitor&&) = delete;
|
||||
~SpirvAstVisitor() = default;
|
||||
|
|
@ -53,10 +53,10 @@ namespace Nz
|
|||
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;
|
||||
|
||||
private:
|
||||
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) const;
|
||||
void PushResultId(UInt32 value);
|
||||
UInt32 PopResultId();
|
||||
|
||||
ShaderAst::AstCache* m_cache;
|
||||
std::vector<SpirvBlock>& m_blocks;
|
||||
std::vector<UInt32> m_resultIds;
|
||||
SpirvBlock* m_currentBlock;
|
||||
|
|
|
|||
|
|
@ -7,13 +7,18 @@
|
|||
|
||||
namespace Nz
|
||||
{
|
||||
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks, ShaderAst::AstCache* cache) :
|
||||
m_cache(cache),
|
||||
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks) :
|
||||
m_blocks(blocks),
|
||||
m_writer(writer)
|
||||
{
|
||||
m_currentBlock = &m_blocks.back();
|
||||
}
|
||||
|
||||
inline const ShaderAst::ExpressionType& SpirvAstVisitor::GetExpressionType(ShaderAst::Expression& expr) const
|
||||
{
|
||||
assert(expr.cachedExpressionType);
|
||||
return expr.cachedExpressionType.value();
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
|
|
|||
|
|
@ -31,11 +31,14 @@ namespace Nz
|
|||
~SpirvConstantCache();
|
||||
|
||||
struct Constant;
|
||||
struct Identifier;
|
||||
struct Type;
|
||||
|
||||
using ConstantPtr = std::shared_ptr<Constant>;
|
||||
using TypePtr = std::shared_ptr<Type>;
|
||||
|
||||
using IdentifierCallback = std::function<TypePtr(const std::string& identifier)>;
|
||||
|
||||
struct Bool {};
|
||||
|
||||
struct Float
|
||||
|
|
@ -63,6 +66,11 @@ namespace Nz
|
|||
UInt32 columnCount;
|
||||
};
|
||||
|
||||
struct Identifier
|
||||
{
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct Image
|
||||
{
|
||||
std::optional<SpirvAccessQualifier> qualifier;
|
||||
|
|
@ -104,7 +112,7 @@ namespace Nz
|
|||
std::vector<Member> members;
|
||||
};
|
||||
|
||||
using AnyType = std::variant<Bool, Float, Function, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
|
||||
using AnyType = std::variant<Bool, Float, Function, Identifier, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
|
||||
|
||||
struct ConstantBool
|
||||
{
|
||||
|
|
@ -166,6 +174,8 @@ namespace Nz
|
|||
UInt32 Register(Type t);
|
||||
UInt32 Register(Variable v);
|
||||
|
||||
void SetIdentifierCallback(IdentifierCallback callback);
|
||||
|
||||
void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
|
||||
|
||||
SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete;
|
||||
|
|
@ -181,6 +191,7 @@ namespace Nz
|
|||
static TypePtr BuildType(const ShaderAst::NoType& type);
|
||||
static TypePtr BuildType(const ShaderAst::PrimitiveType& type);
|
||||
static TypePtr BuildType(const ShaderAst::SamplerType& type);
|
||||
static TypePtr BuildType(const ShaderAst::StructDescription& structDesc);
|
||||
static TypePtr BuildType(const ShaderAst::VectorType& type);
|
||||
|
||||
private:
|
||||
|
|
@ -193,6 +204,7 @@ namespace Nz
|
|||
|
||||
void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
|
||||
|
||||
IdentifierCallback m_identifierCallback;
|
||||
std::unique_ptr<Internal> m_internal;
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
#include <Nazara/Math/Algorithm.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCloner.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
||||
#include <optional>
|
||||
|
|
@ -22,31 +21,87 @@ namespace Nz
|
|||
static const char* flipYUniformName = "_NzFlipValue";
|
||||
static const char* overridenMain = "_NzMain";
|
||||
|
||||
struct AstAdapter : ShaderAst::AstCloner
|
||||
//FIXME: Have this only once
|
||||
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
|
||||
{ "frag", ShaderStageType::Fragment },
|
||||
{ "vert", ShaderStageType::Vertex },
|
||||
};
|
||||
|
||||
struct PreVisitor : ShaderAst::AstCloner
|
||||
{
|
||||
using AstCloner::Clone;
|
||||
|
||||
std::unique_ptr<ShaderAst::DeclareFunctionStatement> Clone(ShaderAst::DeclareFunctionStatement& node) override
|
||||
ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
auto clone = AstCloner::Clone(node);
|
||||
if (clone->name == "main")
|
||||
clone->name = "_NzMain";
|
||||
assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement);
|
||||
|
||||
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
|
||||
|
||||
bool hasEntryPoint = false;
|
||||
|
||||
for (auto& attribute : func->attributes)
|
||||
{
|
||||
if (attribute.type == ShaderAst::AttributeType::Entry)
|
||||
{
|
||||
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
|
||||
assert(it != s_entryPoints.end());
|
||||
|
||||
if (it->second == selectedEntryPoint)
|
||||
{
|
||||
hasEntryPoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasEntryPoint)
|
||||
return ShaderBuilder::NoOp();
|
||||
|
||||
entryPoint = func;
|
||||
|
||||
if (func->name == "main")
|
||||
func->name = "_NzMain";
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node)
|
||||
{
|
||||
if (removedEntryPoints.find(&node) != removedEntryPoints.end())
|
||||
{
|
||||
PushStatement(ShaderBuilder::NoOp());
|
||||
return;
|
||||
}
|
||||
ShaderStageType selectedEntryPoint;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
|
||||
};
|
||||
|
||||
AstCloner::Visit(node);
|
||||
struct EntryFuncResolver : ShaderAst::AstScopedVisitor
|
||||
{
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
|
||||
|
||||
if (&node != entryPoint)
|
||||
return;
|
||||
|
||||
assert(node.parameters.size() == 1);
|
||||
|
||||
const ShaderAst::ExpressionType& inputType = node.parameters.front().type;
|
||||
const ShaderAst::ExpressionType& outputType = node.returnType;
|
||||
|
||||
const Identifier* identifier;
|
||||
|
||||
assert(IsIdentifierType(node.parameters.front().type));
|
||||
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(inputType).name);
|
||||
assert(identifier);
|
||||
|
||||
inputIdentifier = *identifier;
|
||||
|
||||
assert(IsIdentifierType(outputType));
|
||||
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(outputType).name);
|
||||
assert(identifier);
|
||||
|
||||
outputIdentifier = *identifier;
|
||||
}
|
||||
|
||||
std::unordered_set<ShaderAst::DeclareFunctionStatement*> removedEntryPoints;
|
||||
Identifier inputIdentifier;
|
||||
Identifier outputIdentifier;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint;
|
||||
};
|
||||
|
||||
struct Builtin
|
||||
|
|
@ -64,7 +119,6 @@ namespace Nz
|
|||
struct GlslWriter::State
|
||||
{
|
||||
const States* states = nullptr;
|
||||
ShaderAst::AstCache cache;
|
||||
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
|
||||
std::stringstream stream;
|
||||
unsigned int indentLevel = 0;
|
||||
|
|
@ -86,29 +140,18 @@ namespace Nz
|
|||
});
|
||||
|
||||
std::string error;
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
|
||||
if (!state.entryFunc)
|
||||
PreVisitor previsitor;
|
||||
previsitor.selectedEntryPoint = shaderStage;
|
||||
|
||||
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader);
|
||||
|
||||
if (!previsitor.entryPoint)
|
||||
throw std::runtime_error("missing entry point");
|
||||
|
||||
AstAdapter adapter;
|
||||
|
||||
for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions)
|
||||
{
|
||||
if (entryFunc != state.entryFunc)
|
||||
adapter.removedEntryPoints.insert(entryFunc);
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader);
|
||||
|
||||
state.cache.Clear();
|
||||
if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache))
|
||||
throw std::runtime_error("Internal error:" + error);
|
||||
|
||||
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
|
||||
assert(state.entryFunc);
|
||||
state.entryFunc = previsitor.entryPoint;
|
||||
|
||||
unsigned int glslVersion;
|
||||
if (m_environment.glES)
|
||||
|
|
@ -190,10 +233,14 @@ namespace Nz
|
|||
AppendLine();
|
||||
}
|
||||
|
||||
adaptedShader->Visit(*this);
|
||||
PushScope();
|
||||
{
|
||||
adaptedShader->Visit(*this);
|
||||
|
||||
// Append true GLSL entry point
|
||||
AppendEntryPoint(shaderStage);
|
||||
// Append true GLSL entry point
|
||||
AppendEntryPoint(shaderStage, adaptedShader);
|
||||
}
|
||||
PopScope();
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
|
|
@ -340,8 +387,12 @@ namespace Nz
|
|||
AppendLine();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage)
|
||||
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader)
|
||||
{
|
||||
EntryFuncResolver entryResolver;
|
||||
entryResolver.entryPoint = m_currentState->entryFunc;
|
||||
entryResolver.ScopedVisit(shader);
|
||||
|
||||
AppendLine();
|
||||
AppendLine("// Entry point handling");
|
||||
|
||||
|
|
@ -354,15 +405,10 @@ namespace Nz
|
|||
std::vector<InOutField> inputFields;
|
||||
const ShaderAst::StructDescription* inputStruct = nullptr;
|
||||
|
||||
auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
|
||||
auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
|
||||
{
|
||||
assert(IsIdentifierType(expressionType));
|
||||
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::IdentifierType>(expressionType).name);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier.value));
|
||||
const auto& s = std::get<ShaderAst::StructDescription>(identifier.value);
|
||||
|
||||
for (const auto& member : s.members)
|
||||
{
|
||||
|
|
@ -426,17 +472,12 @@ namespace Nz
|
|||
};
|
||||
|
||||
if (!m_currentState->entryFunc->parameters.empty())
|
||||
{
|
||||
assert(m_currentState->entryFunc->parameters.size() == 1);
|
||||
const auto& parameter = m_currentState->entryFunc->parameters.front();
|
||||
|
||||
inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_");
|
||||
}
|
||||
inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_");
|
||||
|
||||
std::vector<InOutField> outputFields;
|
||||
const ShaderAst::StructDescription* outputStruct = nullptr;
|
||||
if (!IsNoType(m_currentState->entryFunc->returnType))
|
||||
outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_");
|
||||
outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_");
|
||||
|
||||
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
|
||||
AppendLine("uniform float ", flipYUniformName, ";");
|
||||
|
|
@ -486,12 +527,12 @@ namespace Nz
|
|||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
{
|
||||
Append(".");
|
||||
Append(memberIdentifier[0]);
|
||||
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
|
||||
const Identifier* identifier = FindIdentifier(structName);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
|
|
@ -503,7 +544,7 @@ namespace Nz
|
|||
const auto& member = *memberIt;
|
||||
|
||||
if (remainingMembers > 1)
|
||||
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
|
||||
AppendField(std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
|
||||
}
|
||||
|
||||
void GlslWriter::AppendLine(const std::string& txt)
|
||||
|
|
@ -558,12 +599,10 @@ namespace Nz
|
|||
{
|
||||
Visit(node.structExpr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
|
||||
const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value();
|
||||
assert(IsIdentifierType(exprType));
|
||||
|
||||
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
|
||||
|
||||
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
|
||||
AppendField(std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
|
||||
|
|
@ -593,7 +632,9 @@ namespace Nz
|
|||
AppendLine(")");
|
||||
|
||||
EnterScope();
|
||||
PushScope();
|
||||
statement.statement->Visit(*this);
|
||||
PopScope();
|
||||
LeaveScope();
|
||||
|
||||
first = false;
|
||||
|
|
@ -604,7 +645,9 @@ namespace Nz
|
|||
AppendLine("else");
|
||||
|
||||
EnterScope();
|
||||
PushScope();
|
||||
node.elseStatement->Visit(*this);
|
||||
PopScope();
|
||||
LeaveScope();
|
||||
}
|
||||
}
|
||||
|
|
@ -698,6 +741,8 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
|
||||
|
||||
for (const auto& externalVar : node.externalVars)
|
||||
{
|
||||
std::optional<long long> bindingIndex;
|
||||
|
|
@ -729,7 +774,7 @@ namespace Nz
|
|||
|
||||
EnterScope();
|
||||
{
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
|
||||
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
|
|
@ -780,15 +825,19 @@ namespace Nz
|
|||
Append(")\n");
|
||||
|
||||
EnterScope();
|
||||
PushScope();
|
||||
{
|
||||
for (auto& statement : node.statements)
|
||||
statement->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||
{
|
||||
RegisterStruct(node.description);
|
||||
|
||||
Append("struct ");
|
||||
AppendLine(node.description.name);
|
||||
EnterScope();
|
||||
|
|
@ -813,6 +862,8 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
RegisterVariable(node.varName, node.varType);
|
||||
|
||||
Append(node.varType);
|
||||
Append(" ");
|
||||
Append(node.varName);
|
||||
|
|
@ -871,6 +922,8 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
|
||||
bool first = true;
|
||||
for (const ShaderAst::StatementPtr& statement : node.statements)
|
||||
{
|
||||
|
|
@ -881,6 +934,8 @@ namespace Nz
|
|||
|
||||
first = false;
|
||||
}
|
||||
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ namespace Nz::ShaderAst
|
|||
return PopStatement();
|
||||
}
|
||||
|
||||
std::unique_ptr<DeclareFunctionStatement> AstCloner::Clone(DeclareFunctionStatement& node)
|
||||
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->attributes = node.attributes;
|
||||
|
|
@ -63,6 +63,8 @@ namespace Nz::ShaderAst
|
|||
clone->memberIdentifiers = node.memberIdentifiers;
|
||||
clone->structExpr = CloneExpression(node.structExpr);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -73,6 +75,8 @@ namespace Nz::ShaderAst
|
|||
clone->left = CloneExpression(node.left);
|
||||
clone->right = CloneExpression(node.right);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -83,6 +87,8 @@ namespace Nz::ShaderAst
|
|||
clone->left = CloneExpression(node.left);
|
||||
clone->right = CloneExpression(node.right);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -100,6 +106,8 @@ namespace Nz::ShaderAst
|
|||
clone->expressions[expressionCount++] = CloneExpression(expr);
|
||||
}
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -110,6 +118,8 @@ namespace Nz::ShaderAst
|
|||
clone->falsePath = CloneExpression(node.falsePath);
|
||||
clone->truePath = CloneExpression(node.truePath);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -118,6 +128,8 @@ namespace Nz::ShaderAst
|
|||
auto clone = std::make_unique<ConstantExpression>();
|
||||
clone->value = node.value;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -126,6 +138,8 @@ namespace Nz::ShaderAst
|
|||
auto clone = std::make_unique<IdentifierExpression>();
|
||||
clone->identifier = node.identifier;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -138,6 +152,8 @@ namespace Nz::ShaderAst
|
|||
for (auto& parameter : node.parameters)
|
||||
clone->parameters.push_back(CloneExpression(parameter));
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
@ -148,6 +164,8 @@ namespace Nz::ShaderAst
|
|||
clone->components = node.components;
|
||||
clone->expression = CloneExpression(node.expression);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,258 +0,0 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCache.hpp>
|
||||
#include <optional>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache)
|
||||
{
|
||||
m_cache = cache;
|
||||
ExpressionType type = GetExpressionTypeInternal(expression);
|
||||
m_cache = nullptr;
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
ExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
|
||||
{
|
||||
m_lastExpressionType.reset();
|
||||
|
||||
Visit(expression);
|
||||
|
||||
assert(m_lastExpressionType.has_value());
|
||||
return std::move(*m_lastExpressionType);
|
||||
}
|
||||
|
||||
ExpressionType ExpressionTypeVisitor::ResolveAlias(Expression& expression, ExpressionType expressionType)
|
||||
{
|
||||
if (IsIdentifierType(expressionType))
|
||||
{
|
||||
auto scopeIt = m_cache->scopeIdByNode.find(&expression);
|
||||
if (scopeIt == m_cache->scopeIdByNode.end())
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
|
||||
if (identifier && std::holds_alternative<AstCache::Alias>(identifier->value))
|
||||
{
|
||||
const AstCache::Alias& alias = std::get<AstCache::Alias>(identifier->value);
|
||||
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, ExpressionType>)
|
||||
return arg;
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, alias.value);
|
||||
}
|
||||
}
|
||||
|
||||
return expressionType;
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(Expression& expression)
|
||||
{
|
||||
if (m_cache)
|
||||
{
|
||||
auto it = m_cache->nodeExpressionType.find(&expression);
|
||||
if (it != m_cache->nodeExpressionType.end())
|
||||
{
|
||||
m_lastExpressionType = it->second;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
expression.Visit(*this);
|
||||
|
||||
if (m_cache)
|
||||
{
|
||||
assert(m_lastExpressionType.has_value());
|
||||
m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
auto scopeIt = m_cache->scopeIdByNode.find(&node);
|
||||
if (scopeIt == m_cache->scopeIdByNode.end())
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
ExpressionType expressionType = ResolveAlias(node, GetExpressionTypeInternal(*node.structExpr));
|
||||
if (!IsIdentifierType(expressionType))
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
|
||||
|
||||
throw std::runtime_error("unhandled accessmember expression");
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(AssignExpression& node)
|
||||
{
|
||||
Visit(*node.left);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(BinaryExpression& node)
|
||||
{
|
||||
switch (node.op)
|
||||
{
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
return Visit(*node.left);
|
||||
|
||||
case BinaryType::Divide:
|
||||
case BinaryType::Multiply:
|
||||
{
|
||||
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.left));
|
||||
ExpressionType rightExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.right));
|
||||
|
||||
if (IsPrimitiveType(leftExprType))
|
||||
{
|
||||
switch (std::get<PrimitiveType>(leftExprType))
|
||||
{
|
||||
case PrimitiveType::Boolean:
|
||||
m_lastExpressionType = std::move(leftExprType);
|
||||
break;
|
||||
|
||||
case PrimitiveType::Float32:
|
||||
case PrimitiveType::Int32:
|
||||
case PrimitiveType::UInt32:
|
||||
m_lastExpressionType = std::move(rightExprType);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else if (IsMatrixType(leftExprType))
|
||||
{
|
||||
if (IsVectorType(rightExprType))
|
||||
m_lastExpressionType = std::move(rightExprType);
|
||||
else
|
||||
m_lastExpressionType = std::move(leftExprType);
|
||||
}
|
||||
else if (IsVectorType(leftExprType))
|
||||
m_lastExpressionType = std::move(leftExprType);
|
||||
else
|
||||
throw std::runtime_error("validation failure");
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompNe:
|
||||
m_lastExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(CastExpression& node)
|
||||
{
|
||||
m_lastExpressionType = node.targetType;
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
|
||||
{
|
||||
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.truePath));
|
||||
assert(leftExprType == ResolveAlias(node, GetExpressionTypeInternal(*node.falsePath)));
|
||||
|
||||
m_lastExpressionType = std::move(leftExprType);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
|
||||
{
|
||||
m_lastExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return PrimitiveType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return PrimitiveType::Float32;
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return PrimitiveType::Int32;
|
||||
else if constexpr (std::is_same_v<T, UInt32>)
|
||||
return PrimitiveType::UInt32;
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return VectorType{ 2, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return VectorType{ 3, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return VectorType{ 4, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return VectorType{ 2, PrimitiveType::Int32 };
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return VectorType{ 3, PrimitiveType::Int32 };
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return VectorType{ 4, PrimitiveType::Int32 };
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, node.value);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(IdentifierExpression& node)
|
||||
{
|
||||
assert(m_cache);
|
||||
|
||||
auto scopeIt = m_cache->scopeIdByNode.find(&node);
|
||||
if (scopeIt == m_cache->scopeIdByNode.end())
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier);
|
||||
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
m_lastExpressionType = ResolveAlias(node, std::get<AstCache::Variable>(identifier->value).type);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
Visit(*node.parameters.front());
|
||||
break;
|
||||
|
||||
case IntrinsicType::DotProduct:
|
||||
m_lastExpressionType = PrimitiveType::Float32;
|
||||
break;
|
||||
|
||||
case IntrinsicType::SampleTexture:
|
||||
{
|
||||
if (node.parameters.empty())
|
||||
throw std::runtime_error("validation failure");
|
||||
|
||||
ExpressionType firstParamType = ResolveAlias(node, GetExpressionTypeInternal(*node.parameters.front()));
|
||||
|
||||
if (!IsSamplerType(firstParamType))
|
||||
throw std::runtime_error("validation failure");
|
||||
|
||||
const auto& sampler = std::get<SamplerType>(firstParamType);
|
||||
|
||||
m_lastExpressionType = VectorType{
|
||||
4,
|
||||
sampler.sampledType
|
||||
};
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
|
||||
{
|
||||
ExpressionType exprType = GetExpressionTypeInternal(*node.expression);
|
||||
|
||||
if (IsMatrixType(exprType))
|
||||
m_lastExpressionType = std::get<MatrixType>(exprType).type;
|
||||
else if (IsVectorType(exprType))
|
||||
m_lastExpressionType = std::get<VectorType>(exprType).type;
|
||||
else
|
||||
throw std::runtime_error("validation failure");
|
||||
}
|
||||
}
|
||||
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
|
@ -453,8 +452,11 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
auto& constant = static_cast<ConstantExpression&>(*cond);
|
||||
|
||||
assert(IsPrimitiveType(GetExpressionType(constant)));
|
||||
assert(std::get<PrimitiveType>(GetExpressionType(constant)) == PrimitiveType::Boolean);
|
||||
assert(constant.cachedExpressionType);
|
||||
const ExpressionType& constantType = constant.cachedExpressionType.value();
|
||||
|
||||
assert(IsPrimitiveType(constantType));
|
||||
assert(std::get<PrimitiveType>(constantType) == PrimitiveType::Boolean);
|
||||
|
||||
bool cValue = std::get<bool>(constant.value);
|
||||
if (!cValue)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
void AstScopedVisitor::ScopedVisit(StatementPtr& nodePtr)
|
||||
{
|
||||
PushScope(); //< Global scope
|
||||
{
|
||||
nodePtr->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(BranchStatement& node)
|
||||
{
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
cond.condition->Visit(*this);
|
||||
cond.statement->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
if (node.elseStatement)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
node.elseStatement->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(ConditionalStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(DeclareExternalStatement& node)
|
||||
{
|
||||
for (auto& extVar : node.externalVars)
|
||||
{
|
||||
ExpressionType subType = extVar.type;
|
||||
if (IsUniformType(subType))
|
||||
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
|
||||
|
||||
RegisterVariable(extVar.name, std::move(subType));
|
||||
}
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
for (auto& parameter : node.parameters)
|
||||
RegisterVariable(parameter.name, parameter.type);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
RegisterStruct(node.description);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
RegisterVariable(node.varName, node.varType);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstScopedVisitor::Visit(MultiStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void AstScopedVisitor::PushScope()
|
||||
{
|
||||
m_scopeSizes.push_back(m_identifiersInScope.size());
|
||||
}
|
||||
|
||||
void AstScopedVisitor::PopScope()
|
||||
{
|
||||
assert(!m_scopeSizes.empty());
|
||||
m_identifiersInScope.resize(m_scopeSizes.back());
|
||||
m_scopeSizes.pop_back();
|
||||
}
|
||||
}
|
||||
|
|
@ -5,7 +5,6 @@
|
|||
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
||||
#include <Nazara/Core/CallOnExit.hpp>
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
|
@ -27,29 +26,21 @@ namespace Nz::ShaderAst
|
|||
|
||||
struct AstValidator::Context
|
||||
{
|
||||
//const ShaderAst::Function* currentFunction;
|
||||
std::optional<std::size_t> activeScopeId;
|
||||
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
|
||||
std::unordered_set<std::string> declaredExternalVar;
|
||||
std::unordered_set<long long> usedBindingIndexes;;
|
||||
AstCache* cache;
|
||||
std::unordered_set<long long> usedBindingIndexes;
|
||||
};
|
||||
|
||||
bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache)
|
||||
bool AstValidator::Validate(StatementPtr& node, std::string* error)
|
||||
{
|
||||
try
|
||||
{
|
||||
AstCache dummy;
|
||||
|
||||
Context currentContext;
|
||||
currentContext.cache = (cache) ? cache : &dummy;
|
||||
|
||||
m_context = ¤tContext;
|
||||
CallOnExit resetContext([&] { m_context = nullptr; });
|
||||
|
||||
EnterScope();
|
||||
node->Visit(*this);
|
||||
ExitScope();
|
||||
|
||||
ScopedVisit(node);
|
||||
return true;
|
||||
}
|
||||
catch (const AstError& e)
|
||||
|
|
@ -61,6 +52,12 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
}
|
||||
|
||||
const ExpressionType& AstValidator::GetExpressionType(Expression& expression)
|
||||
{
|
||||
assert(expression.cachedExpressionType);
|
||||
return ResolveAlias(expression.cachedExpressionType.value());
|
||||
}
|
||||
|
||||
Expression& AstValidator::MandatoryExpr(ExpressionPtr& node)
|
||||
{
|
||||
if (!node)
|
||||
|
|
@ -79,7 +76,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
|
||||
{
|
||||
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
|
||||
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
|
||||
}
|
||||
|
||||
void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right)
|
||||
|
|
@ -90,7 +87,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
{
|
||||
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
|
||||
const Identifier* identifier = FindIdentifier(structName);
|
||||
if (!identifier)
|
||||
throw AstError{ "unknown identifier " + structName };
|
||||
|
||||
|
|
@ -111,81 +108,69 @@ namespace Nz::ShaderAst
|
|||
return member.type;
|
||||
}
|
||||
|
||||
AstCache::Scope& AstValidator::EnterScope()
|
||||
const ExpressionType& AstValidator::ResolveAlias(const ExpressionType& expressionType)
|
||||
{
|
||||
std::size_t newScopeId = m_context->cache->scopes.size();
|
||||
if (!IsIdentifierType(expressionType))
|
||||
return expressionType;
|
||||
|
||||
std::optional<std::size_t> previousScope = m_context->activeScopeId;
|
||||
const Identifier* identifier = FindIdentifier(std::get<IdentifierType>(expressionType).name);
|
||||
if (identifier && std::holds_alternative<Alias>(identifier->value))
|
||||
{
|
||||
const Alias& alias = std::get<Alias>(identifier->value);
|
||||
return std::visit([&](auto&& arg) -> const ShaderAst::ExpressionType&
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
auto& newScope = m_context->cache->scopes.emplace_back();
|
||||
newScope.parentScopeIndex = previousScope;
|
||||
if constexpr (std::is_same_v<T, ExpressionType>)
|
||||
return arg;
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, alias.value);
|
||||
}
|
||||
|
||||
m_context->activeScopeId = newScopeId;
|
||||
return m_context->cache->scopes[newScopeId];
|
||||
}
|
||||
|
||||
void AstValidator::ExitScope()
|
||||
{
|
||||
assert(m_context->activeScopeId);
|
||||
auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
m_context->activeScopeId = previousScope.parentScopeIndex;
|
||||
}
|
||||
|
||||
void AstValidator::RegisterExpressionType(Expression& node, ExpressionType expressionType)
|
||||
{
|
||||
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
|
||||
}
|
||||
|
||||
void AstValidator::RegisterScope(Node& node)
|
||||
{
|
||||
if (m_context->activeScopeId)
|
||||
m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId;
|
||||
return expressionType;
|
||||
}
|
||||
|
||||
void AstValidator::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
// Register expressions types
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
|
||||
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr));
|
||||
if (!IsIdentifierType(exprType))
|
||||
throw AstError{ "expression is not a structure" };
|
||||
|
||||
const std::string& structName = std::get<IdentifierType>(exprType).name;
|
||||
|
||||
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
|
||||
node.cachedExpressionType = CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size());
|
||||
}
|
||||
|
||||
void AstValidator::Visit(AssignExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
MandatoryExpr(node.left);
|
||||
MandatoryExpr(node.right);
|
||||
|
||||
// Register expressions types
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
TypeMustMatch(node.left, node.right);
|
||||
|
||||
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
|
||||
throw AstError { "Assignation is only possible with a l-value" };
|
||||
|
||||
node.cachedExpressionType = GetExpressionType(*node.right);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(BinaryExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
// Register expression type
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
|
||||
ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left));
|
||||
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
|
||||
throw AstError{ "left expression type does not support binary operation" };
|
||||
|
||||
ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
|
||||
ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right));
|
||||
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
|
||||
throw AstError{ "right expression type does not support binary operation" };
|
||||
|
||||
|
|
@ -201,12 +186,18 @@ namespace Nz::ShaderAst
|
|||
if (leftType == PrimitiveType::Boolean)
|
||||
throw AstError{ "this operation is not supported for booleans" };
|
||||
|
||||
[[fallthrough]];
|
||||
TypeMustMatch(node.left, node.right);
|
||||
|
||||
node.cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
|
||||
node.cachedExpressionType = leftExprType;
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
|
|
@ -219,9 +210,20 @@ namespace Nz::ShaderAst
|
|||
case PrimitiveType::UInt32:
|
||||
{
|
||||
if (IsMatrixType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
|
||||
node.cachedExpressionType = rightExprType;
|
||||
}
|
||||
else if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, rightExprType);
|
||||
node.cachedExpressionType = leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
|
||||
node.cachedExpressionType = rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
|
|
@ -248,18 +250,29 @@ namespace Nz::ShaderAst
|
|||
case BinaryType::CompLt:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
node.cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
node.cachedExpressionType = leftExprType;
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
if (IsMatrixType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftExprType, rightExprType);
|
||||
node.cachedExpressionType = leftExprType; //< FIXME
|
||||
}
|
||||
else if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType.type, rightExprType);
|
||||
node.cachedExpressionType = leftExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
const VectorType& rightType = std::get<VectorType>(rightExprType);
|
||||
|
|
@ -267,6 +280,8 @@ namespace Nz::ShaderAst
|
|||
|
||||
if (leftType.columnCount != rightType.componentCount)
|
||||
throw AstError{ "incompatible types" };
|
||||
|
||||
node.cachedExpressionType = rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
|
|
@ -275,7 +290,7 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
else if (IsVectorType(leftExprType))
|
||||
{
|
||||
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
|
||||
const VectorType& leftType = std::get<VectorType>(leftExprType);
|
||||
switch (node.op)
|
||||
{
|
||||
case BinaryType::CompGe:
|
||||
|
|
@ -284,16 +299,29 @@ namespace Nz::ShaderAst
|
|||
case BinaryType::CompLt:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
node.cachedExpressionType = PrimitiveType::Boolean;
|
||||
break;
|
||||
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
node.cachedExpressionType = leftExprType;
|
||||
break;
|
||||
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
if (IsPrimitiveType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType.type, rightExprType);
|
||||
node.cachedExpressionType = rightExprType;
|
||||
}
|
||||
else if (IsVectorType(rightExprType))
|
||||
{
|
||||
TypeMustMatch(leftType, rightExprType);
|
||||
node.cachedExpressionType = rightExprType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "incompatible types" };
|
||||
}
|
||||
|
|
@ -303,11 +331,9 @@ namespace Nz::ShaderAst
|
|||
|
||||
void AstValidator::Visit(CastExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
auto GetComponentCount = [](const ExpressionType& exprType) -> unsigned int
|
||||
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
|
||||
{
|
||||
if (IsPrimitiveType(exprType))
|
||||
return 1;
|
||||
|
|
@ -317,15 +343,15 @@ namespace Nz::ShaderAst
|
|||
throw AstError{ "wut" };
|
||||
};
|
||||
|
||||
unsigned int componentCount = 0;
|
||||
unsigned int requiredComponents = GetComponentCount(node.targetType);
|
||||
std::size_t componentCount = 0;
|
||||
std::size_t requiredComponents = GetComponentCount(node.targetType);
|
||||
|
||||
for (auto& exprPtr : node.expressions)
|
||||
{
|
||||
if (!exprPtr)
|
||||
break;
|
||||
|
||||
ExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
|
||||
ExpressionType exprType = GetExpressionType(*exprPtr);
|
||||
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
||||
throw AstError{ "incompatible type" };
|
||||
|
||||
|
|
@ -334,6 +360,40 @@ namespace Nz::ShaderAst
|
|||
|
||||
if (componentCount != requiredComponents)
|
||||
throw AstError{ "component count doesn't match required component count" };
|
||||
|
||||
node.cachedExpressionType = node.targetType;
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ConstantExpression& node)
|
||||
{
|
||||
node.cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return PrimitiveType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return PrimitiveType::Float32;
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return PrimitiveType::Int32;
|
||||
else if constexpr (std::is_same_v<T, UInt32>)
|
||||
return PrimitiveType::UInt32;
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return VectorType{ 2, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return VectorType{ 3, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return VectorType{ 4, PrimitiveType::Float32 };
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return VectorType{ 2, PrimitiveType::Int32 };
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return VectorType{ 3, PrimitiveType::Int32 };
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return VectorType{ 4, PrimitiveType::Int32 };
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, node.value);
|
||||
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ConditionalExpression& node)
|
||||
|
|
@ -341,37 +401,31 @@ namespace Nz::ShaderAst
|
|||
MandatoryExpr(node.truePath);
|
||||
MandatoryExpr(node.falsePath);
|
||||
|
||||
RegisterScope(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
ExpressionType leftExprType = GetExpressionType(*node.truePath);
|
||||
if (leftExprType != GetExpressionType(*node.falsePath))
|
||||
throw AstError{ "true path type must match false path type" };
|
||||
|
||||
node.cachedExpressionType = leftExprType;
|
||||
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
|
||||
// throw AstError{ "condition not found" };
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ConstantExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(IdentifierExpression& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
if (!m_context->activeScopeId)
|
||||
throw AstError{ "no scope" };
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier);
|
||||
const Identifier* identifier = FindIdentifier(node.identifier);
|
||||
if (!identifier)
|
||||
throw AstError{ "Unknown variable " + node.identifier };
|
||||
throw AstError{ "Unknown identifier " + node.identifier };
|
||||
|
||||
node.cachedExpressionType = ResolveAlias(std::get<Variable>(identifier->value).type);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(IntrinsicExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
|
|
@ -384,10 +438,11 @@ namespace Nz::ShaderAst
|
|||
for (auto& param : node.parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
ExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
|
||||
ExpressionType type = GetExpressionType(*node.parameters.front());
|
||||
|
||||
for (std::size_t i = 1; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (type != GetExpressionType(MandatoryExpr(node.parameters[i]), m_context->cache))
|
||||
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])))
|
||||
throw AstError{ "All type must match" };
|
||||
}
|
||||
|
||||
|
|
@ -402,11 +457,11 @@ namespace Nz::ShaderAst
|
|||
for (auto& param : node.parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
if (!IsSamplerType(GetExpressionType(*node.parameters[0], m_context->cache)))
|
||||
if (!IsSamplerType(GetExpressionType(*node.parameters[0])))
|
||||
throw AstError{ "First parameter must be a sampler" };
|
||||
|
||||
if (!IsVectorType(GetExpressionType(*node.parameters[1], m_context->cache)))
|
||||
throw AstError{ "First parameter must be a vector" };
|
||||
if (!IsVectorType(GetExpressionType(*node.parameters[1])))
|
||||
throw AstError{ "Second parameter must be a vector" };
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -414,63 +469,89 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
{
|
||||
if (GetExpressionType(*node.parameters[0]) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
|
||||
ExpressionType type = GetExpressionType(*node.parameters.front());
|
||||
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
|
||||
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
|
||||
|
||||
node.cachedExpressionType = std::move(type);
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::DotProduct:
|
||||
{
|
||||
ExpressionType type = GetExpressionType(*node.parameters.front());
|
||||
if (!IsVectorType(type))
|
||||
throw AstError{ "DotProduct expects vector types" };
|
||||
|
||||
node.cachedExpressionType = std::get<VectorType>(type).type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::SampleTexture:
|
||||
{
|
||||
node.cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*node.parameters.front())).sampledType };
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AstValidator::Visit(SwizzleExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
if (node.componentCount > 4)
|
||||
throw AstError{ "Cannot swizzle more than four elements" };
|
||||
|
||||
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
|
||||
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
||||
throw AstError{ "Cannot swizzle this type" };
|
||||
MandatoryExpr(node.expression);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
ExpressionType exprType = GetExpressionType(*node.expression);
|
||||
if (IsPrimitiveType(exprType) || IsVectorType(exprType))
|
||||
{
|
||||
PrimitiveType baseType;
|
||||
if (IsPrimitiveType(exprType))
|
||||
baseType = std::get<PrimitiveType>(exprType);
|
||||
else
|
||||
baseType = std::get<VectorType>(exprType).type;
|
||||
|
||||
if (node.componentCount > 1)
|
||||
{
|
||||
node.cachedExpressionType = VectorType{
|
||||
node.componentCount,
|
||||
baseType
|
||||
};
|
||||
}
|
||||
else
|
||||
node.cachedExpressionType = baseType;
|
||||
}
|
||||
else
|
||||
throw AstError{ "Cannot swizzle this type" };
|
||||
}
|
||||
|
||||
void AstValidator::Visit(BranchStatement& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
for (auto& condStatement : node.condStatements)
|
||||
{
|
||||
ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
|
||||
ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition));
|
||||
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
|
||||
throw AstError{ "if expression must resolve to boolean type" };
|
||||
|
||||
MandatoryStatement(condStatement.statement);
|
||||
}
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ConditionalStatement& node)
|
||||
{
|
||||
MandatoryStatement(node.statement);
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
|
||||
// throw AstError{ "condition not found" };
|
||||
}
|
||||
|
||||
void AstValidator::Visit(DeclareExternalStatement& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
|
||||
for (const auto& [attributeType, arg] : node.attributes)
|
||||
{
|
||||
switch (attributeType)
|
||||
|
|
@ -513,7 +594,7 @@ namespace Nz::ShaderAst
|
|||
throw AstError{ "attribute layout requires a string parameter" };
|
||||
|
||||
if (std::get<std::string>(arg) != "std140")
|
||||
throw AstError{ "unknow layout type" };
|
||||
throw AstError{ "unknown layout type" };
|
||||
|
||||
hasLayout = true;
|
||||
break;
|
||||
|
|
@ -528,14 +609,9 @@ namespace Nz::ShaderAst
|
|||
throw AstError{ "External variable " + extVar.name + " is already declared" };
|
||||
|
||||
m_context->declaredExternalVar.insert(extVar.name);
|
||||
|
||||
ExpressionType subType = extVar.type;
|
||||
if (IsUniformType(subType))
|
||||
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
|
||||
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ extVar.name, AstCache::Variable { std::move(subType) } };
|
||||
}
|
||||
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(DeclareFunctionStatement& node)
|
||||
|
|
@ -561,10 +637,10 @@ namespace Nz::ShaderAst
|
|||
|
||||
ShaderStageType stageType = it->second;
|
||||
|
||||
if (m_context->cache->entryFunctions[UnderlyingCast(stageType)])
|
||||
if (m_context->entryFunctions[UnderlyingCast(stageType)])
|
||||
throw AstError{ "the same entry type has been defined multiple times" };
|
||||
|
||||
m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node;
|
||||
m_context->entryFunctions[UnderlyingCast(it->second)] = &node;
|
||||
|
||||
if (node.parameters.size() > 1)
|
||||
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
|
||||
|
|
@ -578,103 +654,41 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
}
|
||||
|
||||
auto& scope = EnterScope();
|
||||
RegisterScope(node);
|
||||
|
||||
for (auto& parameter : node.parameters)
|
||||
{
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } };
|
||||
}
|
||||
|
||||
for (auto& statement : node.statements)
|
||||
MandatoryStatement(statement).Visit(*this);
|
||||
MandatoryStatement(statement);
|
||||
|
||||
ExitScope();
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
if (!m_context->activeScopeId)
|
||||
throw AstError{ "cannot declare variable without scope" };
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
//TODO: check members attributes
|
||||
|
||||
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ node.description.name, node.description };
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
if (!m_context->activeScopeId)
|
||||
throw AstError{ "cannot declare variable without scope" };
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } };
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ExpressionStatement& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
MandatoryExpr(node.expression);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(MultiStatement& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
EnterScope();
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
for (auto& statement : node.statements)
|
||||
MandatoryStatement(statement);
|
||||
|
||||
ExitScope();
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ReturnStatement& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
/*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void))
|
||||
{
|
||||
if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType)
|
||||
throw AstError{ "Return type doesn't match function return type" };
|
||||
}
|
||||
else
|
||||
{
|
||||
if (node.returnExpr)
|
||||
throw AstError{ "Unexpected expression for return (function doesn't return)" };
|
||||
}*/
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache)
|
||||
bool ValidateAst(StatementPtr& node, std::string* error)
|
||||
{
|
||||
AstValidator validator;
|
||||
return validator.Validate(node, error, cache);
|
||||
return validator.Validate(node, error);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Core/StackVector.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
|
||||
#include <Nazara/Shader/SpirvExpressionStore.hpp>
|
||||
|
|
@ -39,13 +38,13 @@ namespace Nz
|
|||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
|
||||
{
|
||||
ShaderAst::ExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
|
||||
ShaderAst::ExpressionType resultExprType = GetExpressionType(node);
|
||||
assert(IsPrimitiveType(resultExprType));
|
||||
|
||||
ShaderAst::ExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
|
||||
ShaderAst::ExpressionType leftExprType = GetExpressionType(*node.left);
|
||||
assert(IsPrimitiveType(leftExprType));
|
||||
|
||||
ShaderAst::ExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
|
||||
ShaderAst::ExpressionType rightExprType = GetExpressionType(*node.right);
|
||||
assert(IsPrimitiveType(rightExprType));
|
||||
|
||||
ShaderAst::PrimitiveType resultType = std::get<ShaderAst::PrimitiveType>(resultExprType);
|
||||
|
|
@ -582,7 +581,7 @@ namespace Nz
|
|||
{
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
{
|
||||
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
|
||||
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(vecExprType));
|
||||
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
||||
|
|
@ -626,7 +625,7 @@ namespace Nz
|
|||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
ShaderAst::ExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
|
||||
ShaderAst::ExpressionType targetExprType = GetExpressionType(node);
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
||||
|
|
|
|||
|
|
@ -50,6 +50,11 @@ namespace Nz
|
|||
return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType);
|
||||
}
|
||||
|
||||
bool Compare(const Identifier& lhs, const Identifier& rhs) const
|
||||
{
|
||||
return lhs.name == rhs.name;
|
||||
}
|
||||
|
||||
bool Compare(const Image& lhs, const Image& rhs) const
|
||||
{
|
||||
return lhs.arrayed == rhs.arrayed
|
||||
|
|
@ -226,6 +231,11 @@ namespace Nz
|
|||
void Register(const Integer&) {}
|
||||
void Register(const Void&) {}
|
||||
|
||||
void Register(const Identifier& identifier)
|
||||
{
|
||||
Register(identifier);
|
||||
}
|
||||
|
||||
void Register(const Image& image)
|
||||
{
|
||||
Register(image.sampledType);
|
||||
|
|
@ -456,6 +466,11 @@ namespace Nz
|
|||
UInt32 SpirvConstantCache::Register(Type t)
|
||||
{
|
||||
AnyType& type = t.type;
|
||||
if (std::holds_alternative<Identifier>(type))
|
||||
{
|
||||
assert(m_identifierCallback);
|
||||
return Register(*m_identifierCallback(std::get<Identifier>(type).name));
|
||||
}
|
||||
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(type);
|
||||
|
|
@ -487,6 +502,11 @@ namespace Nz
|
|||
return it.value();
|
||||
}
|
||||
|
||||
void SpirvConstantCache::SetIdentifierCallback(IdentifierCallback callback)
|
||||
{
|
||||
m_identifierCallback = std::move(callback);
|
||||
}
|
||||
|
||||
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
|
||||
{
|
||||
for (auto&& [object, id] : m_internal->ids)
|
||||
|
|
@ -597,7 +617,7 @@ namespace Nz
|
|||
return std::make_shared<Type>(Pointer{
|
||||
BuildType(type),
|
||||
storageClass
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr
|
||||
|
|
@ -605,37 +625,16 @@ namespace Nz
|
|||
return std::visit([&](auto&& arg) -> TypePtr
|
||||
{
|
||||
return BuildType(arg);
|
||||
/*else if constexpr (std::is_same_v<T, std::string>)
|
||||
{
|
||||
// Register struct members type
|
||||
const auto& structs = shader.GetStructs();
|
||||
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
|
||||
if (it == structs.end())
|
||||
throw std::runtime_error("struct " + arg + " has not been defined");
|
||||
|
||||
const ShaderAst::Struct& s = *it;
|
||||
|
||||
Structure sType;
|
||||
sType.name = s.name;
|
||||
|
||||
for (const auto& member : s.members)
|
||||
{
|
||||
auto& sMembers = sType.members.emplace_back();
|
||||
sMembers.name = member.name;
|
||||
sMembers.type = BuildType(shader, member.type);
|
||||
}
|
||||
|
||||
return std::make_shared<Type>(std::move(sType));
|
||||
return nullptr;
|
||||
}
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");*/
|
||||
}, type);
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr
|
||||
{
|
||||
throw std::runtime_error("unexpected type");
|
||||
return std::make_shared<Type>(
|
||||
Identifier{
|
||||
type.name
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr
|
||||
|
|
@ -691,6 +690,21 @@ namespace Nz
|
|||
return std::make_shared<Type>(SampledImage{ std::make_shared<Type>(imageType) });
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) -> TypePtr
|
||||
{
|
||||
Structure sType;
|
||||
sType.name = structDesc.name;
|
||||
|
||||
for (const auto& member : structDesc.members)
|
||||
{
|
||||
auto& sMembers = sType.members.emplace_back();
|
||||
sMembers.name = member.name;
|
||||
sMembers.type = BuildType(member.type);
|
||||
}
|
||||
|
||||
return std::make_shared<Type>(std::move(sType));
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(Vector{ BuildType(type.type), UInt32(type.componentCount) });
|
||||
|
|
@ -767,6 +781,8 @@ namespace Nz
|
|||
appender(GetId(*param));
|
||||
});
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, Identifier>)
|
||||
throw std::runtime_error("unexpected identifier");
|
||||
else if constexpr (std::is_same_v<T, Image>)
|
||||
{
|
||||
UInt32 depth;
|
||||
|
|
@ -915,6 +931,8 @@ namespace Nz
|
|||
}
|
||||
else if constexpr (std::is_same_v<T, Function>)
|
||||
throw std::runtime_error("unexpected function as struct member");
|
||||
else if constexpr (std::is_same_v<T, Identifier>)
|
||||
throw std::runtime_error("unexpected identifier");
|
||||
else if constexpr (std::is_same_v<T, Image> || std::is_same_v<T, SampledImage>)
|
||||
throw std::runtime_error("unexpected opaque type as struct member");
|
||||
else if constexpr (std::is_same_v<T, Void>)
|
||||
|
|
|
|||
|
|
@ -25,18 +25,26 @@ namespace Nz
|
|||
{
|
||||
namespace
|
||||
{
|
||||
class PreVisitor : public ShaderAst::AstRecursiveVisitor
|
||||
class PreVisitor : public ShaderAst::AstScopedVisitor
|
||||
{
|
||||
public:
|
||||
using ExtInstList = std::unordered_set<std::string>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
|
||||
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
|
||||
|
||||
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_cache(cache),
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
{
|
||||
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(identifierName);
|
||||
if (!identifier)
|
||||
throw std::runtime_error("invalid identifier " + identifierName);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
|
||||
});
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::AccessMemberExpression& node) override
|
||||
|
|
@ -74,7 +82,7 @@ namespace Nz
|
|||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
|
||||
}, node.value);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
|
|
@ -87,11 +95,13 @@ namespace Nz
|
|||
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareStructStatement& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
SpirvConstantCache::Structure sType;
|
||||
sType.name = node.description.name;
|
||||
|
||||
|
|
@ -107,21 +117,21 @@ namespace Nz
|
|||
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache)));
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IntrinsicExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
|
|
@ -140,7 +150,6 @@ namespace Nz
|
|||
FunctionContainer funcs;
|
||||
|
||||
private:
|
||||
ShaderAst::AstCache* m_cache;
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
};
|
||||
|
|
@ -214,7 +223,7 @@ namespace Nz
|
|||
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
std::string error;
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.states = &conditions;
|
||||
|
|
@ -229,7 +238,7 @@ namespace Nz
|
|||
ShaderAst::AstCloner cloner;
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache);
|
||||
shader->Visit(preVisitor);
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
|
|
@ -397,7 +406,7 @@ namespace Nz
|
|||
state.parameterIds.emplace(param.name, std::move(parameterData));
|
||||
}
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache);
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks);
|
||||
for (const auto& statement : func.statements)
|
||||
statement->Visit(visitor);
|
||||
|
||||
|
|
@ -419,7 +428,7 @@ namespace Nz
|
|||
|
||||
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
|
||||
{
|
||||
const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
|
||||
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
|
||||
if (!statement)
|
||||
continue;
|
||||
|
||||
|
|
@ -462,7 +471,7 @@ namespace Nz
|
|||
});
|
||||
|
||||
if (stage == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
|
||||
}
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
|
|
|
|||
Loading…
Reference in New Issue