Shader: Rework scope handling

This commit is contained in:
Jérôme Leclercq 2021-04-04 20:31:09 +02:00
parent feffcfa6e5
commit f93a5bbdc1
23 changed files with 661 additions and 755 deletions

View File

@ -32,12 +32,9 @@
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Shader/Shader.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstNodes.hpp>
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderAstSerializer.hpp>

View File

@ -9,7 +9,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <set>
#include <sstream>
@ -17,7 +17,7 @@
namespace Nz
{
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstRecursiveVisitor
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstScopedVisitor
{
public:
struct Environment;
@ -57,8 +57,8 @@ namespace Nz
template<typename T> void Append(const T& param);
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
void AppendCommentSection(const std::string& section);
void AppendEntryPoint(ShaderStageType shaderStage);
void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
void AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader);
void AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);

View File

@ -1,56 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTCACHE_HPP
#define NAZARA_SHADERASTCACHE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Utility/Enums.hpp>
namespace Nz::ShaderAst
{
struct AstCache
{
struct Identifier;
struct Alias
{
std::variant<ExpressionType> value;
};
struct Variable
{
ExpressionType type;
};
struct Identifier
{
std::string name;
std::variant<Alias, Variable, StructDescription> value;
};
struct Scope
{
std::optional<std::size_t> parentScopeIndex;
std::vector<Identifier> identifiers;
};
inline void Clear();
inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const;
inline std::size_t GetScopeId(const Node* node) const;
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_map<const Expression*, ExpressionType> nodeExpressionType;
std::unordered_map<const Node*, std::size_t> scopeIdByNode;
std::vector<Scope> scopes;
};
}
#include <Nazara/Shader/ShaderAstCache.inl>
#endif

View File

@ -1,45 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline void AstCache::Clear()
{
entryFunctions.fill(nullptr);
nodeExpressionType.clear();
scopeIdByNode.clear();
scopes.clear();
}
inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier*
{
assert(startingScopeId < scopes.size());
std::optional<std::size_t> scopeId = startingScopeId;
do
{
const auto& scope = scopes[*scopeId];
auto it = std::find_if(scope.identifiers.rbegin(), scope.identifiers.rend(), [&](const auto& identifier) { return identifier.name == identifierName; });
if (it != scope.identifiers.rend())
return &*it;
scopeId = scope.parentScopeIndex;
} while (scopeId);
return nullptr;
}
inline std::size_t AstCache::GetScopeId(const Node* node) const
{
auto it = scopeIdByNode.find(node);
assert(it != scopeIdByNode.end());
return it->second;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -33,7 +33,7 @@ namespace Nz::ShaderAst
ExpressionPtr CloneExpression(ExpressionPtr& expr);
StatementPtr CloneStatement(StatementPtr& statement);
virtual std::unique_ptr<DeclareFunctionStatement> Clone(DeclareFunctionStatement& node);
virtual StatementPtr Clone(DeclareFunctionStatement& node);
using AstExpressionVisitor::Visit;
using AstStatementVisitor::Visit;

View File

@ -1,58 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTEXPRESSIONTYPE_HPP
#define NAZARA_SHADERASTEXPRESSIONTYPE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <vector>
namespace Nz::ShaderAst
{
struct AstCache;
class NAZARA_SHADER_API ExpressionTypeVisitor : public AstExpressionVisitor
{
public:
ExpressionTypeVisitor() = default;
ExpressionTypeVisitor(const ExpressionTypeVisitor&) = delete;
ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete;
~ExpressionTypeVisitor() = default;
ExpressionType GetExpressionType(Expression& expression, AstCache* cache);
ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete;
ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete;
private:
ExpressionType GetExpressionTypeInternal(Expression& expression);
ExpressionType ResolveAlias(Expression& expression, ExpressionType expressionType);
void Visit(Expression& expression);
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
AstCache* m_cache;
std::optional<ExpressionType> m_lastExpressionType;
};
inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr);
}
#include <Nazara/Shader/ShaderAstExpressionType.inl>
#endif

View File

@ -1,17 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache)
{
ExpressionTypeVisitor visitor;
return visitor.GetExpressionType(expression, cache);
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -0,0 +1,68 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_SCOPED_VISITOR_HPP
#define NAZARA_SHADER_SCOPED_VISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API AstScopedVisitor : public AstRecursiveVisitor
{
public:
struct Identifier;
AstScopedVisitor() = default;
~AstScopedVisitor() = default;
inline const Identifier* FindIdentifier(const std::string_view& identifierName) const;
void ScopedVisit(StatementPtr& nodePtr);
using AstRecursiveVisitor::Visit;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(MultiStatement& node) override;
struct Alias
{
std::variant<ExpressionType> value;
};
struct Variable
{
ExpressionType type;
};
struct Identifier
{
std::string name;
std::variant<Alias, Variable, StructDescription> value;
};
protected:
void PushScope();
void PopScope();
inline void RegisterStruct(StructDescription structDesc);
inline void RegisterVariable(std::string name, ExpressionType type);
private:
std::vector<Identifier> m_identifiersInScope;
std::vector<std::size_t> m_scopeSizes;
};
}
#include <Nazara/Shader/ShaderAstScopedVisitor.inl>
#endif

View File

@ -0,0 +1,38 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline auto AstScopedVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
{
auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
if (it == m_identifiersInScope.rend())
return nullptr;
return &*it;
}
inline void AstScopedVisitor::RegisterStruct(StructDescription structDesc)
{
std::string name = structDesc.name;
m_identifiersInScope.push_back({
std::move(name),
std::move(structDesc)
});
}
inline void AstScopedVisitor::RegisterVariable(std::string name, ExpressionType type)
{
m_identifiersInScope.push_back({
std::move(name),
Variable { std::move(type) }
});
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -9,13 +9,12 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
#include <Nazara/Utility/Enums.hpp>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API AstValidator : public AstRecursiveVisitor
class NAZARA_SHADER_API AstValidator final : public AstScopedVisitor
{
public:
inline AstValidator();
@ -23,28 +22,24 @@ namespace Nz::ShaderAst
AstValidator(AstValidator&&) = delete;
~AstValidator() = default;
bool Validate(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr);
bool Validate(StatementPtr& node, std::string* error = nullptr);
private:
const ExpressionType& GetExpressionType(Expression& expression);
Expression& MandatoryExpr(ExpressionPtr& node);
Statement& MandatoryStatement(StatementPtr& node);
void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right);
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right);
ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
AstCache::Scope& EnterScope();
void ExitScope();
void RegisterExpressionType(Expression& node, ExpressionType expressionType);
void RegisterScope(Node& node);
const ExpressionType& ResolveAlias(const ExpressionType& expressionType);
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
@ -54,17 +49,15 @@ namespace Nz::ShaderAst
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(ExpressionStatement& node) override;
void Visit(MultiStatement& node) override;
void Visit(ReturnStatement& node) override;
struct Context;
Context* m_context;
};
NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr);
NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr);
}
#include <Nazara/Shader/ShaderAstValidator.inl>

View File

@ -56,6 +56,8 @@ namespace Nz::ShaderAst
Expression& operator=(const Expression&) = delete;
Expression& operator=(Expression&&) noexcept = default;
std::optional<ExpressionType> cachedExpressionType;
};
struct NAZARA_SHADER_API AccessMemberExpression : public Expression

View File

@ -21,7 +21,7 @@ namespace Nz
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept
{
public:
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks, ShaderAst::AstCache* cache);
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks);
SpirvAstVisitor(const SpirvAstVisitor&) = delete;
SpirvAstVisitor(SpirvAstVisitor&&) = delete;
~SpirvAstVisitor() = default;
@ -53,10 +53,10 @@ namespace Nz
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;
private:
inline const ShaderAst::ExpressionType& GetExpressionType(ShaderAst::Expression& expr) const;
void PushResultId(UInt32 value);
UInt32 PopResultId();
ShaderAst::AstCache* m_cache;
std::vector<SpirvBlock>& m_blocks;
std::vector<UInt32> m_resultIds;
SpirvBlock* m_currentBlock;

View File

@ -7,13 +7,18 @@
namespace Nz
{
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks, ShaderAst::AstCache* cache) :
m_cache(cache),
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks) :
m_blocks(blocks),
m_writer(writer)
{
m_currentBlock = &m_blocks.back();
}
inline const ShaderAst::ExpressionType& SpirvAstVisitor::GetExpressionType(ShaderAst::Expression& expr) const
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -31,11 +31,14 @@ namespace Nz
~SpirvConstantCache();
struct Constant;
struct Identifier;
struct Type;
using ConstantPtr = std::shared_ptr<Constant>;
using TypePtr = std::shared_ptr<Type>;
using IdentifierCallback = std::function<TypePtr(const std::string& identifier)>;
struct Bool {};
struct Float
@ -63,6 +66,11 @@ namespace Nz
UInt32 columnCount;
};
struct Identifier
{
std::string name;
};
struct Image
{
std::optional<SpirvAccessQualifier> qualifier;
@ -104,7 +112,7 @@ namespace Nz
std::vector<Member> members;
};
using AnyType = std::variant<Bool, Float, Function, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
using AnyType = std::variant<Bool, Float, Function, Identifier, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
struct ConstantBool
{
@ -166,6 +174,8 @@ namespace Nz
UInt32 Register(Type t);
UInt32 Register(Variable v);
void SetIdentifierCallback(IdentifierCallback callback);
void Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete;
@ -181,6 +191,7 @@ namespace Nz
static TypePtr BuildType(const ShaderAst::NoType& type);
static TypePtr BuildType(const ShaderAst::PrimitiveType& type);
static TypePtr BuildType(const ShaderAst::SamplerType& type);
static TypePtr BuildType(const ShaderAst::StructDescription& structDesc);
static TypePtr BuildType(const ShaderAst::VectorType& type);
private:
@ -193,6 +204,7 @@ namespace Nz
void WriteStruct(const Structure& structData, UInt32 resultId, SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos);
IdentifierCallback m_identifierCallback;
std::unique_ptr<Internal> m_internal;
};
}

View File

@ -8,7 +8,6 @@
#include <Nazara/Math/Algorithm.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <optional>
@ -22,31 +21,87 @@ namespace Nz
static const char* flipYUniformName = "_NzFlipValue";
static const char* overridenMain = "_NzMain";
struct AstAdapter : ShaderAst::AstCloner
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct PreVisitor : ShaderAst::AstCloner
{
using AstCloner::Clone;
std::unique_ptr<ShaderAst::DeclareFunctionStatement> Clone(ShaderAst::DeclareFunctionStatement& node) override
ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override
{
auto clone = AstCloner::Clone(node);
if (clone->name == "main")
clone->name = "_NzMain";
assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement);
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
bool hasEntryPoint = false;
for (auto& attribute : func->attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
if (it->second == selectedEntryPoint)
{
hasEntryPoint = true;
break;
}
}
}
if (!hasEntryPoint)
return ShaderBuilder::NoOp();
entryPoint = func;
if (func->name == "main")
func->name = "_NzMain";
return clone;
}
void Visit(ShaderAst::DeclareFunctionStatement& node)
{
if (removedEntryPoints.find(&node) != removedEntryPoints.end())
{
PushStatement(ShaderBuilder::NoOp());
return;
}
ShaderStageType selectedEntryPoint;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
};
AstCloner::Visit(node);
struct EntryFuncResolver : ShaderAst::AstScopedVisitor
{
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
if (&node != entryPoint)
return;
assert(node.parameters.size() == 1);
const ShaderAst::ExpressionType& inputType = node.parameters.front().type;
const ShaderAst::ExpressionType& outputType = node.returnType;
const Identifier* identifier;
assert(IsIdentifierType(node.parameters.front().type));
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(inputType).name);
assert(identifier);
inputIdentifier = *identifier;
assert(IsIdentifierType(outputType));
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(outputType).name);
assert(identifier);
outputIdentifier = *identifier;
}
std::unordered_set<ShaderAst::DeclareFunctionStatement*> removedEntryPoints;
Identifier inputIdentifier;
Identifier outputIdentifier;
ShaderAst::DeclareFunctionStatement* entryPoint;
};
struct Builtin
@ -64,7 +119,6 @@ namespace Nz
struct GlslWriter::State
{
const States* states = nullptr;
ShaderAst::AstCache cache;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
unsigned int indentLevel = 0;
@ -86,29 +140,18 @@ namespace Nz
});
std::string error;
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
if (!state.entryFunc)
PreVisitor previsitor;
previsitor.selectedEntryPoint = shaderStage;
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader);
if (!previsitor.entryPoint)
throw std::runtime_error("missing entry point");
AstAdapter adapter;
for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions)
{
if (entryFunc != state.entryFunc)
adapter.removedEntryPoints.insert(entryFunc);
}
ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader);
state.cache.Clear();
if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache))
throw std::runtime_error("Internal error:" + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
assert(state.entryFunc);
state.entryFunc = previsitor.entryPoint;
unsigned int glslVersion;
if (m_environment.glES)
@ -190,10 +233,14 @@ namespace Nz
AppendLine();
}
adaptedShader->Visit(*this);
PushScope();
{
adaptedShader->Visit(*this);
// Append true GLSL entry point
AppendEntryPoint(shaderStage);
// Append true GLSL entry point
AppendEntryPoint(shaderStage, adaptedShader);
}
PopScope();
return state.stream.str();
}
@ -340,8 +387,12 @@ namespace Nz
AppendLine();
}
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage)
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader)
{
EntryFuncResolver entryResolver;
entryResolver.entryPoint = m_currentState->entryFunc;
entryResolver.ScopedVisit(shader);
AppendLine();
AppendLine("// Entry point handling");
@ -354,15 +405,10 @@ namespace Nz
std::vector<InOutField> inputFields;
const ShaderAst::StructDescription* inputStruct = nullptr;
auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
{
assert(IsIdentifierType(expressionType));
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::IdentifierType>(expressionType).name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier.value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier.value);
for (const auto& member : s.members)
{
@ -426,17 +472,12 @@ namespace Nz
};
if (!m_currentState->entryFunc->parameters.empty())
{
assert(m_currentState->entryFunc->parameters.size() == 1);
const auto& parameter = m_currentState->entryFunc->parameters.front();
inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_");
}
inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_");
std::vector<InOutField> outputFields;
const ShaderAst::StructDescription* outputStruct = nullptr;
if (!IsNoType(m_currentState->entryFunc->returnType))
outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_");
outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_");
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
AppendLine("uniform float ", flipYUniformName, ";");
@ -486,12 +527,12 @@ namespace Nz
LeaveScope();
}
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
Append(".");
Append(memberIdentifier[0]);
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
const Identifier* identifier = FindIdentifier(structName);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
@ -503,7 +544,7 @@ namespace Nz
const auto& member = *memberIt;
if (remainingMembers > 1)
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
AppendField(std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
}
void GlslWriter::AppendLine(const std::string& txt)
@ -558,12 +599,10 @@ namespace Nz
{
Visit(node.structExpr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value();
assert(IsIdentifierType(exprType));
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
AppendField(std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
@ -593,7 +632,9 @@ namespace Nz
AppendLine(")");
EnterScope();
PushScope();
statement.statement->Visit(*this);
PopScope();
LeaveScope();
first = false;
@ -604,7 +645,9 @@ namespace Nz
AppendLine("else");
EnterScope();
PushScope();
node.elseStatement->Visit(*this);
PopScope();
LeaveScope();
}
}
@ -698,6 +741,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
@ -729,7 +774,7 @@ namespace Nz
EnterScope();
{
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
@ -780,15 +825,19 @@ namespace Nz
Append(")\n");
EnterScope();
PushScope();
{
for (auto& statement : node.statements)
statement->Visit(*this);
}
PopScope();
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
RegisterStruct(node.description);
Append("struct ");
AppendLine(node.description.name);
EnterScope();
@ -813,6 +862,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
RegisterVariable(node.varName, node.varType);
Append(node.varType);
Append(" ");
Append(node.varName);
@ -871,6 +922,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
{
PushScope();
bool first = true;
for (const ShaderAst::StatementPtr& statement : node.statements)
{
@ -881,6 +934,8 @@ namespace Nz
first = false;
}
PopScope();
}
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)

View File

@ -42,7 +42,7 @@ namespace Nz::ShaderAst
return PopStatement();
}
std::unique_ptr<DeclareFunctionStatement> AstCloner::Clone(DeclareFunctionStatement& node)
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes;
@ -63,6 +63,8 @@ namespace Nz::ShaderAst
clone->memberIdentifiers = node.memberIdentifiers;
clone->structExpr = CloneExpression(node.structExpr);
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -73,6 +75,8 @@ namespace Nz::ShaderAst
clone->left = CloneExpression(node.left);
clone->right = CloneExpression(node.right);
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -83,6 +87,8 @@ namespace Nz::ShaderAst
clone->left = CloneExpression(node.left);
clone->right = CloneExpression(node.right);
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -100,6 +106,8 @@ namespace Nz::ShaderAst
clone->expressions[expressionCount++] = CloneExpression(expr);
}
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -110,6 +118,8 @@ namespace Nz::ShaderAst
clone->falsePath = CloneExpression(node.falsePath);
clone->truePath = CloneExpression(node.truePath);
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -118,6 +128,8 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<ConstantExpression>();
clone->value = node.value;
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -126,6 +138,8 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<IdentifierExpression>();
clone->identifier = node.identifier;
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -138,6 +152,8 @@ namespace Nz::ShaderAst
for (auto& parameter : node.parameters)
clone->parameters.push_back(CloneExpression(parameter));
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}
@ -148,6 +164,8 @@ namespace Nz::ShaderAst
clone->components = node.components;
clone->expression = CloneExpression(node.expression);
clone->cachedExpressionType = node.cachedExpressionType;
PushExpression(std::move(clone));
}

View File

@ -1,258 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <optional>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
ExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache)
{
m_cache = cache;
ExpressionType type = GetExpressionTypeInternal(expression);
m_cache = nullptr;
return type;
}
ExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
{
m_lastExpressionType.reset();
Visit(expression);
assert(m_lastExpressionType.has_value());
return std::move(*m_lastExpressionType);
}
ExpressionType ExpressionTypeVisitor::ResolveAlias(Expression& expression, ExpressionType expressionType)
{
if (IsIdentifierType(expressionType))
{
auto scopeIt = m_cache->scopeIdByNode.find(&expression);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
if (identifier && std::holds_alternative<AstCache::Alias>(identifier->value))
{
const AstCache::Alias& alias = std::get<AstCache::Alias>(identifier->value);
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ExpressionType>)
return arg;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, alias.value);
}
}
return expressionType;
}
void ExpressionTypeVisitor::Visit(Expression& expression)
{
if (m_cache)
{
auto it = m_cache->nodeExpressionType.find(&expression);
if (it != m_cache->nodeExpressionType.end())
{
m_lastExpressionType = it->second;
return;
}
}
expression.Visit(*this);
if (m_cache)
{
assert(m_lastExpressionType.has_value());
m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType);
}
}
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
{
auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
ExpressionType expressionType = ResolveAlias(node, GetExpressionTypeInternal(*node.structExpr));
if (!IsIdentifierType(expressionType))
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
throw std::runtime_error("unhandled accessmember expression");
}
void ExpressionTypeVisitor::Visit(AssignExpression& node)
{
Visit(*node.left);
}
void ExpressionTypeVisitor::Visit(BinaryExpression& node)
{
switch (node.op)
{
case BinaryType::Add:
case BinaryType::Subtract:
return Visit(*node.left);
case BinaryType::Divide:
case BinaryType::Multiply:
{
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.left));
ExpressionType rightExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.right));
if (IsPrimitiveType(leftExprType))
{
switch (std::get<PrimitiveType>(leftExprType))
{
case PrimitiveType::Boolean:
m_lastExpressionType = std::move(leftExprType);
break;
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
m_lastExpressionType = std::move(rightExprType);
break;
}
}
else if (IsMatrixType(leftExprType))
{
if (IsVectorType(rightExprType))
m_lastExpressionType = std::move(rightExprType);
else
m_lastExpressionType = std::move(leftExprType);
}
else if (IsVectorType(leftExprType))
m_lastExpressionType = std::move(leftExprType);
else
throw std::runtime_error("validation failure");
break;
}
case BinaryType::CompEq:
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
m_lastExpressionType = PrimitiveType::Boolean;
break;
}
}
void ExpressionTypeVisitor::Visit(CastExpression& node)
{
m_lastExpressionType = node.targetType;
}
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
{
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.truePath));
assert(leftExprType == ResolveAlias(node, GetExpressionTypeInternal(*node.falsePath)));
m_lastExpressionType = std::move(leftExprType);
}
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
{
m_lastExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, node.value);
}
void ExpressionTypeVisitor::Visit(IdentifierExpression& node)
{
assert(m_cache);
auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier);
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
throw std::runtime_error("internal error");
m_lastExpressionType = ResolveAlias(node, std::get<AstCache::Variable>(identifier->value).type);
}
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
Visit(*node.parameters.front());
break;
case IntrinsicType::DotProduct:
m_lastExpressionType = PrimitiveType::Float32;
break;
case IntrinsicType::SampleTexture:
{
if (node.parameters.empty())
throw std::runtime_error("validation failure");
ExpressionType firstParamType = ResolveAlias(node, GetExpressionTypeInternal(*node.parameters.front()));
if (!IsSamplerType(firstParamType))
throw std::runtime_error("validation failure");
const auto& sampler = std::get<SamplerType>(firstParamType);
m_lastExpressionType = VectorType{
4,
sampler.sampledType
};
break;
}
}
}
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
{
ExpressionType exprType = GetExpressionTypeInternal(*node.expression);
if (IsMatrixType(exprType))
m_lastExpressionType = std::get<MatrixType>(exprType).type;
else if (IsVectorType(exprType))
m_lastExpressionType = std::get<VectorType>(exprType).type;
else
throw std::runtime_error("validation failure");
}
}

View File

@ -4,7 +4,6 @@
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <cassert>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
@ -453,8 +452,11 @@ namespace Nz::ShaderAst
{
auto& constant = static_cast<ConstantExpression&>(*cond);
assert(IsPrimitiveType(GetExpressionType(constant)));
assert(std::get<PrimitiveType>(GetExpressionType(constant)) == PrimitiveType::Boolean);
assert(constant.cachedExpressionType);
const ExpressionType& constantType = constant.cachedExpressionType.value();
assert(IsPrimitiveType(constantType));
assert(std::get<PrimitiveType>(constantType) == PrimitiveType::Boolean);
bool cValue = std::get<bool>(constant.value);
if (!cValue)

View File

@ -0,0 +1,110 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
void AstScopedVisitor::ScopedVisit(StatementPtr& nodePtr)
{
PushScope(); //< Global scope
{
nodePtr->Visit(*this);
}
PopScope();
}
void AstScopedVisitor::Visit(BranchStatement& node)
{
for (auto& cond : node.condStatements)
{
PushScope();
{
cond.condition->Visit(*this);
cond.statement->Visit(*this);
}
PopScope();
}
if (node.elseStatement)
{
PushScope();
{
node.elseStatement->Visit(*this);
}
PopScope();
}
}
void AstScopedVisitor::Visit(ConditionalStatement& node)
{
PushScope();
{
AstRecursiveVisitor::Visit(node);
}
PopScope();
}
void AstScopedVisitor::Visit(DeclareExternalStatement& node)
{
for (auto& extVar : node.externalVars)
{
ExpressionType subType = extVar.type;
if (IsUniformType(subType))
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
RegisterVariable(extVar.name, std::move(subType));
}
AstRecursiveVisitor::Visit(node);
}
void AstScopedVisitor::Visit(DeclareFunctionStatement& node)
{
PushScope();
{
for (auto& parameter : node.parameters)
RegisterVariable(parameter.name, parameter.type);
AstRecursiveVisitor::Visit(node);
}
PopScope();
}
void AstScopedVisitor::Visit(DeclareStructStatement& node)
{
RegisterStruct(node.description);
AstRecursiveVisitor::Visit(node);
}
void AstScopedVisitor::Visit(DeclareVariableStatement& node)
{
RegisterVariable(node.varName, node.varType);
AstRecursiveVisitor::Visit(node);
}
void AstScopedVisitor::Visit(MultiStatement& node)
{
PushScope();
{
AstRecursiveVisitor::Visit(node);
}
PopScope();
}
void AstScopedVisitor::PushScope()
{
m_scopeSizes.push_back(m_identifiersInScope.size());
}
void AstScopedVisitor::PopScope()
{
assert(!m_scopeSizes.empty());
m_identifiersInScope.resize(m_scopeSizes.back());
m_scopeSizes.pop_back();
}
}

View File

@ -5,7 +5,6 @@
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <unordered_set>
#include <vector>
#include <Nazara/Shader/Debug.hpp>
@ -27,29 +26,21 @@ namespace Nz::ShaderAst
struct AstValidator::Context
{
//const ShaderAst::Function* currentFunction;
std::optional<std::size_t> activeScopeId;
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<long long> usedBindingIndexes;;
AstCache* cache;
std::unordered_set<long long> usedBindingIndexes;
};
bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache)
bool AstValidator::Validate(StatementPtr& node, std::string* error)
{
try
{
AstCache dummy;
Context currentContext;
currentContext.cache = (cache) ? cache : &dummy;
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
EnterScope();
node->Visit(*this);
ExitScope();
ScopedVisit(node);
return true;
}
catch (const AstError& e)
@ -61,6 +52,12 @@ namespace Nz::ShaderAst
}
}
const ExpressionType& AstValidator::GetExpressionType(Expression& expression)
{
assert(expression.cachedExpressionType);
return ResolveAlias(expression.cachedExpressionType.value());
}
Expression& AstValidator::MandatoryExpr(ExpressionPtr& node)
{
if (!node)
@ -79,7 +76,7 @@ namespace Nz::ShaderAst
void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
{
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
}
void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right)
@ -90,7 +87,7 @@ namespace Nz::ShaderAst
ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
const Identifier* identifier = FindIdentifier(structName);
if (!identifier)
throw AstError{ "unknown identifier " + structName };
@ -111,81 +108,69 @@ namespace Nz::ShaderAst
return member.type;
}
AstCache::Scope& AstValidator::EnterScope()
const ExpressionType& AstValidator::ResolveAlias(const ExpressionType& expressionType)
{
std::size_t newScopeId = m_context->cache->scopes.size();
if (!IsIdentifierType(expressionType))
return expressionType;
std::optional<std::size_t> previousScope = m_context->activeScopeId;
const Identifier* identifier = FindIdentifier(std::get<IdentifierType>(expressionType).name);
if (identifier && std::holds_alternative<Alias>(identifier->value))
{
const Alias& alias = std::get<Alias>(identifier->value);
return std::visit([&](auto&& arg) -> const ShaderAst::ExpressionType&
{
using T = std::decay_t<decltype(arg)>;
auto& newScope = m_context->cache->scopes.emplace_back();
newScope.parentScopeIndex = previousScope;
if constexpr (std::is_same_v<T, ExpressionType>)
return arg;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, alias.value);
}
m_context->activeScopeId = newScopeId;
return m_context->cache->scopes[newScopeId];
}
void AstValidator::ExitScope()
{
assert(m_context->activeScopeId);
auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId];
m_context->activeScopeId = previousScope.parentScopeIndex;
}
void AstValidator::RegisterExpressionType(Expression& node, ExpressionType expressionType)
{
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
}
void AstValidator::RegisterScope(Node& node)
{
if (m_context->activeScopeId)
m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId;
return expressionType;
}
void AstValidator::Visit(AccessMemberExpression& node)
{
RegisterScope(node);
// Register expressions types
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr));
if (!IsIdentifierType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<IdentifierType>(exprType).name;
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
node.cachedExpressionType = CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void AstValidator::Visit(AssignExpression& node)
{
RegisterScope(node);
MandatoryExpr(node.left);
MandatoryExpr(node.right);
// Register expressions types
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
TypeMustMatch(node.left, node.right);
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError { "Assignation is only possible with a l-value" };
node.cachedExpressionType = GetExpressionType(*node.right);
}
void AstValidator::Visit(BinaryExpression& node)
{
RegisterScope(node);
// Register expression type
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left));
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right));
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
@ -201,12 +186,18 @@ namespace Nz::ShaderAst
if (leftType == PrimitiveType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
TypeMustMatch(node.left, node.right);
node.cachedExpressionType = PrimitiveType::Boolean;
break;
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
node.cachedExpressionType = leftExprType;
break;
case BinaryType::Multiply:
@ -219,9 +210,20 @@ namespace Nz::ShaderAst
case PrimitiveType::UInt32:
{
if (IsMatrixType(rightExprType))
{
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
node.cachedExpressionType = rightExprType;
}
else if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType, rightExprType);
node.cachedExpressionType = leftExprType;
}
else if (IsVectorType(rightExprType))
{
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
node.cachedExpressionType = rightExprType;
}
else
throw AstError{ "incompatible types" };
@ -248,18 +250,29 @@ namespace Nz::ShaderAst
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
TypeMustMatch(node.left, node.right);
node.cachedExpressionType = PrimitiveType::Boolean;
break;
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
node.cachedExpressionType = leftExprType;
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsMatrixType(rightExprType))
{
TypeMustMatch(leftExprType, rightExprType);
node.cachedExpressionType = leftExprType; //< FIXME
}
else if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType.type, rightExprType);
node.cachedExpressionType = leftExprType;
}
else if (IsVectorType(rightExprType))
{
const VectorType& rightType = std::get<VectorType>(rightExprType);
@ -267,6 +280,8 @@ namespace Nz::ShaderAst
if (leftType.columnCount != rightType.componentCount)
throw AstError{ "incompatible types" };
node.cachedExpressionType = rightExprType;
}
else
throw AstError{ "incompatible types" };
@ -275,7 +290,7 @@ namespace Nz::ShaderAst
}
else if (IsVectorType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
const VectorType& leftType = std::get<VectorType>(leftExprType);
switch (node.op)
{
case BinaryType::CompGe:
@ -284,16 +299,29 @@ namespace Nz::ShaderAst
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
TypeMustMatch(node.left, node.right);
node.cachedExpressionType = PrimitiveType::Boolean;
break;
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
node.cachedExpressionType = leftExprType;
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsPrimitiveType(rightExprType))
{
TypeMustMatch(leftType.type, rightExprType);
node.cachedExpressionType = rightExprType;
}
else if (IsVectorType(rightExprType))
{
TypeMustMatch(leftType, rightExprType);
node.cachedExpressionType = rightExprType;
}
else
throw AstError{ "incompatible types" };
}
@ -303,11 +331,9 @@ namespace Nz::ShaderAst
void AstValidator::Visit(CastExpression& node)
{
RegisterScope(node);
AstScopedVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
auto GetComponentCount = [](const ExpressionType& exprType) -> unsigned int
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
{
if (IsPrimitiveType(exprType))
return 1;
@ -317,15 +343,15 @@ namespace Nz::ShaderAst
throw AstError{ "wut" };
};
unsigned int componentCount = 0;
unsigned int requiredComponents = GetComponentCount(node.targetType);
std::size_t componentCount = 0;
std::size_t requiredComponents = GetComponentCount(node.targetType);
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
ExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
ExpressionType exprType = GetExpressionType(*exprPtr);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "incompatible type" };
@ -334,6 +360,40 @@ namespace Nz::ShaderAst
if (componentCount != requiredComponents)
throw AstError{ "component count doesn't match required component count" };
node.cachedExpressionType = node.targetType;
}
void AstValidator::Visit(ConstantExpression& node)
{
node.cachedExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, node.value);
}
void AstValidator::Visit(ConditionalExpression& node)
@ -341,37 +401,31 @@ namespace Nz::ShaderAst
MandatoryExpr(node.truePath);
MandatoryExpr(node.falsePath);
RegisterScope(node);
AstScopedVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
ExpressionType leftExprType = GetExpressionType(*node.truePath);
if (leftExprType != GetExpressionType(*node.falsePath))
throw AstError{ "true path type must match false path type" };
node.cachedExpressionType = leftExprType;
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
// throw AstError{ "condition not found" };
}
void AstValidator::Visit(ConstantExpression& node)
{
RegisterScope(node);
}
void AstValidator::Visit(IdentifierExpression& node)
{
assert(m_context);
if (!m_context->activeScopeId)
throw AstError{ "no scope" };
RegisterScope(node);
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier);
const Identifier* identifier = FindIdentifier(node.identifier);
if (!identifier)
throw AstError{ "Unknown variable " + node.identifier };
throw AstError{ "Unknown identifier " + node.identifier };
node.cachedExpressionType = ResolveAlias(std::get<Variable>(identifier->value).type);
}
void AstValidator::Visit(IntrinsicExpression& node)
{
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
switch (node.intrinsic)
{
@ -384,10 +438,11 @@ namespace Nz::ShaderAst
for (auto& param : node.parameters)
MandatoryExpr(param);
ExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
ExpressionType type = GetExpressionType(*node.parameters.front());
for (std::size_t i = 1; i < node.parameters.size(); ++i)
{
if (type != GetExpressionType(MandatoryExpr(node.parameters[i]), m_context->cache))
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])))
throw AstError{ "All type must match" };
}
@ -402,11 +457,11 @@ namespace Nz::ShaderAst
for (auto& param : node.parameters)
MandatoryExpr(param);
if (!IsSamplerType(GetExpressionType(*node.parameters[0], m_context->cache)))
if (!IsSamplerType(GetExpressionType(*node.parameters[0])))
throw AstError{ "First parameter must be a sampler" };
if (!IsVectorType(GetExpressionType(*node.parameters[1], m_context->cache)))
throw AstError{ "First parameter must be a vector" };
if (!IsVectorType(GetExpressionType(*node.parameters[1])))
throw AstError{ "Second parameter must be a vector" };
}
}
@ -414,63 +469,89 @@ namespace Nz::ShaderAst
{
case IntrinsicType::CrossProduct:
{
if (GetExpressionType(*node.parameters[0]) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
ExpressionType type = GetExpressionType(*node.parameters.front());
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
node.cachedExpressionType = std::move(type);
break;
}
case IntrinsicType::DotProduct:
{
ExpressionType type = GetExpressionType(*node.parameters.front());
if (!IsVectorType(type))
throw AstError{ "DotProduct expects vector types" };
node.cachedExpressionType = std::get<VectorType>(type).type;
break;
}
case IntrinsicType::SampleTexture:
{
node.cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*node.parameters.front())).sampledType };
break;
}
}
}
void AstValidator::Visit(SwizzleExpression& node)
{
RegisterScope(node);
if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" };
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "Cannot swizzle this type" };
MandatoryExpr(node.expression);
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
ExpressionType exprType = GetExpressionType(*node.expression);
if (IsPrimitiveType(exprType) || IsVectorType(exprType))
{
PrimitiveType baseType;
if (IsPrimitiveType(exprType))
baseType = std::get<PrimitiveType>(exprType);
else
baseType = std::get<VectorType>(exprType).type;
if (node.componentCount > 1)
{
node.cachedExpressionType = VectorType{
node.componentCount,
baseType
};
}
else
node.cachedExpressionType = baseType;
}
else
throw AstError{ "Cannot swizzle this type" };
}
void AstValidator::Visit(BranchStatement& node)
{
RegisterScope(node);
for (auto& condStatement : node.condStatements)
{
ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition));
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
MandatoryStatement(condStatement.statement);
}
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void AstValidator::Visit(ConditionalStatement& node)
{
MandatoryStatement(node.statement);
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
// throw AstError{ "condition not found" };
}
void AstValidator::Visit(DeclareExternalStatement& node)
{
RegisterScope(node);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
@ -513,7 +594,7 @@ namespace Nz::ShaderAst
throw AstError{ "attribute layout requires a string parameter" };
if (std::get<std::string>(arg) != "std140")
throw AstError{ "unknow layout type" };
throw AstError{ "unknown layout type" };
hasLayout = true;
break;
@ -528,14 +609,9 @@ namespace Nz::ShaderAst
throw AstError{ "External variable " + extVar.name + " is already declared" };
m_context->declaredExternalVar.insert(extVar.name);
ExpressionType subType = extVar.type;
if (IsUniformType(subType))
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ extVar.name, AstCache::Variable { std::move(subType) } };
}
AstScopedVisitor::Visit(node);
}
void AstValidator::Visit(DeclareFunctionStatement& node)
@ -561,10 +637,10 @@ namespace Nz::ShaderAst
ShaderStageType stageType = it->second;
if (m_context->cache->entryFunctions[UnderlyingCast(stageType)])
if (m_context->entryFunctions[UnderlyingCast(stageType)])
throw AstError{ "the same entry type has been defined multiple times" };
m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node;
m_context->entryFunctions[UnderlyingCast(it->second)] = &node;
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
@ -578,103 +654,41 @@ namespace Nz::ShaderAst
}
}
auto& scope = EnterScope();
RegisterScope(node);
for (auto& parameter : node.parameters)
{
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } };
}
for (auto& statement : node.statements)
MandatoryStatement(statement).Visit(*this);
MandatoryStatement(statement);
ExitScope();
AstScopedVisitor::Visit(node);
}
void AstValidator::Visit(DeclareStructStatement& node)
{
assert(m_context);
if (!m_context->activeScopeId)
throw AstError{ "cannot declare variable without scope" };
RegisterScope(node);
//TODO: check members attributes
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ node.description.name, node.description };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(DeclareVariableStatement& node)
{
assert(m_context);
if (!m_context->activeScopeId)
throw AstError{ "cannot declare variable without scope" };
RegisterScope(node);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } };
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void AstValidator::Visit(ExpressionStatement& node)
{
RegisterScope(node);
MandatoryExpr(node.expression);
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void AstValidator::Visit(MultiStatement& node)
{
assert(m_context);
EnterScope();
RegisterScope(node);
for (auto& statement : node.statements)
MandatoryStatement(statement);
ExitScope();
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void AstValidator::Visit(ReturnStatement& node)
{
RegisterScope(node);
/*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void))
{
if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType)
throw AstError{ "Return type doesn't match function return type" };
}
else
{
if (node.returnExpr)
throw AstError{ "Unexpected expression for return (function doesn't return)" };
}*/
AstRecursiveVisitor::Visit(node);
}
bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache)
bool ValidateAst(StatementPtr& node, std::string* error)
{
AstValidator validator;
return validator.Validate(node, error, cache);
return validator.Validate(node, error);
}
}

View File

@ -4,7 +4,6 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Shader/SpirvExpressionStore.hpp>
@ -39,13 +38,13 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{
ShaderAst::ExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
ShaderAst::ExpressionType resultExprType = GetExpressionType(node);
assert(IsPrimitiveType(resultExprType));
ShaderAst::ExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
ShaderAst::ExpressionType leftExprType = GetExpressionType(*node.left);
assert(IsPrimitiveType(leftExprType));
ShaderAst::ExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
ShaderAst::ExpressionType rightExprType = GetExpressionType(*node.right);
assert(IsPrimitiveType(rightExprType));
ShaderAst::PrimitiveType resultType = std::get<ShaderAst::PrimitiveType>(resultExprType);
@ -582,7 +581,7 @@ namespace Nz
{
case ShaderAst::IntrinsicType::DotProduct:
{
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
@ -626,7 +625,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
ShaderAst::ExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
ShaderAst::ExpressionType targetExprType = GetExpressionType(node);
assert(IsPrimitiveType(targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);

View File

@ -50,6 +50,11 @@ namespace Nz
return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType);
}
bool Compare(const Identifier& lhs, const Identifier& rhs) const
{
return lhs.name == rhs.name;
}
bool Compare(const Image& lhs, const Image& rhs) const
{
return lhs.arrayed == rhs.arrayed
@ -226,6 +231,11 @@ namespace Nz
void Register(const Integer&) {}
void Register(const Void&) {}
void Register(const Identifier& identifier)
{
Register(identifier);
}
void Register(const Image& image)
{
Register(image.sampledType);
@ -456,6 +466,11 @@ namespace Nz
UInt32 SpirvConstantCache::Register(Type t)
{
AnyType& type = t.type;
if (std::holds_alternative<Identifier>(type))
{
assert(m_identifierCallback);
return Register(*m_identifierCallback(std::get<Identifier>(type).name));
}
DepRegisterer registerer(*this);
registerer.Register(type);
@ -487,6 +502,11 @@ namespace Nz
return it.value();
}
void SpirvConstantCache::SetIdentifierCallback(IdentifierCallback callback)
{
m_identifierCallback = std::move(callback);
}
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
{
for (auto&& [object, id] : m_internal->ids)
@ -597,7 +617,7 @@ namespace Nz
return std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr
@ -605,37 +625,16 @@ namespace Nz
return std::visit([&](auto&& arg) -> TypePtr
{
return BuildType(arg);
/*else if constexpr (std::is_same_v<T, std::string>)
{
// Register struct members type
const auto& structs = shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
throw std::runtime_error("struct " + arg + " has not been defined");
const ShaderAst::Struct& s = *it;
Structure sType;
sType.name = s.name;
for (const auto& member : s.members)
{
auto& sMembers = sType.members.emplace_back();
sMembers.name = member.name;
sMembers.type = BuildType(shader, member.type);
}
return std::make_shared<Type>(std::move(sType));
return nullptr;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");*/
}, type);
}
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr
{
throw std::runtime_error("unexpected type");
return std::make_shared<Type>(
Identifier{
type.name
}
);
}
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr
@ -691,6 +690,21 @@ namespace Nz
return std::make_shared<Type>(SampledImage{ std::make_shared<Type>(imageType) });
}
auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) -> TypePtr
{
Structure sType;
sType.name = structDesc.name;
for (const auto& member : structDesc.members)
{
auto& sMembers = sType.members.emplace_back();
sMembers.name = member.name;
sMembers.type = BuildType(member.type);
}
return std::make_shared<Type>(std::move(sType));
}
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr
{
return std::make_shared<Type>(Vector{ BuildType(type.type), UInt32(type.componentCount) });
@ -767,6 +781,8 @@ namespace Nz
appender(GetId(*param));
});
}
else if constexpr (std::is_same_v<T, Identifier>)
throw std::runtime_error("unexpected identifier");
else if constexpr (std::is_same_v<T, Image>)
{
UInt32 depth;
@ -915,6 +931,8 @@ namespace Nz
}
else if constexpr (std::is_same_v<T, Function>)
throw std::runtime_error("unexpected function as struct member");
else if constexpr (std::is_same_v<T, Identifier>)
throw std::runtime_error("unexpected identifier");
else if constexpr (std::is_same_v<T, Image> || std::is_same_v<T, SampledImage>)
throw std::runtime_error("unexpected opaque type as struct member");
else if constexpr (std::is_same_v<T, Void>)

View File

@ -25,18 +25,26 @@ namespace Nz
{
namespace
{
class PreVisitor : public ShaderAst::AstRecursiveVisitor
class PreVisitor : public ShaderAst::AstScopedVisitor
{
public:
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_cache(cache),
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_conditions(conditions),
m_constantCache(constantCache)
{
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
{
const Identifier* identifier = FindIdentifier(identifierName);
if (!identifier)
throw std::runtime_error("invalid identifier " + identifierName);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
});
}
void Visit(ShaderAst::AccessMemberExpression& node) override
@ -74,7 +82,7 @@ namespace Nz
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
}, node.value);
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
@ -87,11 +95,13 @@ namespace Nz
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareStructStatement& node) override
{
AstScopedVisitor::Visit(node);
SpirvConstantCache::Structure sType;
sType.name = node.description.name;
@ -107,21 +117,21 @@ namespace Nz
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
AstScopedVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache)));
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::IntrinsicExpression& node) override
{
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
switch (node.intrinsic)
{
@ -140,7 +150,6 @@ namespace Nz
FunctionContainer funcs;
private:
ShaderAst::AstCache* m_cache;
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
};
@ -214,7 +223,7 @@ namespace Nz
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
std::string error;
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.states = &conditions;
@ -229,7 +238,7 @@ namespace Nz
ShaderAst::AstCloner cloner;
// Register all extended instruction sets
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
PreVisitor preVisitor(conditions, state.constantTypeCache);
shader->Visit(preVisitor);
for (const std::string& extInst : preVisitor.extInsts)
@ -397,7 +406,7 @@ namespace Nz
state.parameterIds.emplace(param.name, std::move(parameterData));
}
SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache);
SpirvAstVisitor visitor(*this, state.functionBlocks);
for (const auto& statement : func.statements)
statement->Visit(visitor);
@ -419,7 +428,7 @@ namespace Nz
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
{
const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
if (!statement)
continue;
@ -462,7 +471,7 @@ namespace Nz
});
if (stage == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
}
std::vector<UInt32> ret;