Rework shader AST (WIP)

This commit is contained in:
Jérôme Leclercq 2021-03-10 11:18:13 +01:00
parent b320b5b44e
commit fed7370e77
73 changed files with 2721 additions and 4312 deletions

View File

@ -16,6 +16,7 @@
#include <Nazara/Renderer/RenderPipelineLayout.hpp>
#include <Nazara/Renderer/Texture.hpp>
#include <Nazara/Renderer/TextureSampler.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <Nazara/Utility/AbstractBuffer.hpp>
#include <memory>
@ -24,7 +25,6 @@
namespace Nz
{
class CommandPool;
class ShaderAst;
class ShaderStage;
class NAZARA_RENDERER_API RenderDevice
@ -39,7 +39,7 @@ namespace Nz
virtual std::shared_ptr<RenderPass> InstantiateRenderPass(std::vector<RenderPass::Attachment> attachments, std::vector<RenderPass::SubpassDescription> subpassDescriptions, std::vector<RenderPass::SubpassDependency> subpassDependencies) = 0;
virtual std::shared_ptr<RenderPipeline> InstantiateRenderPipeline(RenderPipelineInfo pipelineInfo) = 0;
virtual std::shared_ptr<RenderPipelineLayout> InstantiateRenderPipelineLayout(RenderPipelineLayoutInfo pipelineLayoutInfo) = 0;
virtual std::shared_ptr<ShaderStage> InstantiateShaderStage(const ShaderAst& shaderAst, const ShaderWriter::States& states) = 0;
virtual std::shared_ptr<ShaderStage> InstantiateShaderStage(const ShaderAst::StatementPtr& shaderAst, const ShaderWriter::States& states) = 0;
virtual std::shared_ptr<ShaderStage> InstantiateShaderStage(ShaderStageType type, ShaderLanguage lang, const void* source, std::size_t sourceSize) = 0;
std::shared_ptr<ShaderStage> InstantiateShaderStage(ShaderStageType type, ShaderLanguage lang, const std::filesystem::path& sourcePath);
virtual std::shared_ptr<Texture> InstantiateTexture(const TextureInfo& params) = 0;

View File

@ -32,23 +32,25 @@
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Shader/Shader.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <Nazara/Shader/ShaderLangLexer.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
@ -58,6 +60,7 @@
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Shader/SpirvPrinter.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvSectionBase.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#endif // NAZARA_GLOBAL_SHADER_HPP

View File

@ -9,9 +9,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <set>
#include <sstream>
@ -19,7 +17,7 @@
namespace Nz
{
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderVarVisitor, public ShaderAstVisitor
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstRecursiveVisitor
{
public:
struct Environment;
@ -30,7 +28,7 @@ namespace Nz
GlslWriter(GlslWriter&&) = delete;
~GlslWriter() = default;
std::string Generate(const ShaderAst& shader, const States& conditions = {});
std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {});
void SetEnv(Environment environment);
@ -46,67 +44,45 @@ namespace Nz
static const char* GetFlipYUniformName();
private:
void Append(ShaderExpressionType type);
void Append(ShaderNodes::BuiltinEntry builtin);
void Append(ShaderNodes::BasicType type);
void Append(ShaderNodes::MemoryLayout layout);
void Append(ShaderAst::ShaderExpressionType type);
void Append(ShaderAst::BuiltinEntry builtin);
void Append(ShaderAst::BasicType type);
void Append(ShaderAst::MemoryLayout layout);
template<typename T> void Append(const T& param);
void AppendCommentSection(const std::string& section);
void AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers);
void AppendFunction(const ShaderAst::Function& func);
void AppendFunctionPrototype(const ShaderAst::Function& func);
void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
void AppendLine(const std::string& txt = {});
template<typename T> void DeclareVariables(const ShaderAst& shader, const std::vector<T>& variables, const std::string& keyword = {}, const std::string& section = {});
void EnterScope();
void LeaveScope();
using ShaderVarVisitor::Visit;
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired = false);
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::BuiltinVariable& var) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::Discard& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::NoOp& node) override;
void Visit(ShaderNodes::ParameterVariable& var) override;
void Visit(ShaderNodes::ReturnStatement& node) override;
void Visit(ShaderNodes::OutputVariable& var) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
void Visit(ShaderNodes::UniformVariable& var) override;
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);
static bool HasExplicitBinding(const ShaderAst& shader);
static bool HasExplicitLocation(const ShaderAst& shader);
void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConditionalExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
struct Context
{
const ShaderAst* shader = nullptr;
const ShaderAst::Function* currentFunction = nullptr;
const States* states = nullptr;
};
void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareVariableStatement& node) override;
void Visit(ShaderAst::DiscardStatement& node) override;
void Visit(ShaderAst::ExpressionStatement& node) override;
void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override;
struct State
{
std::stringstream stream;
unsigned int indentLevel = 0;
};
static bool HasExplicitBinding(ShaderAst::StatementPtr& shader);
static bool HasExplicitLocation(ShaderAst::StatementPtr& shader);
struct State;
Context m_context;
Environment m_environment;
State* m_currentState;
};

View File

@ -3,130 +3,10 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/GlslWriter.hpp>
#include <type_traits>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
template<typename T>
void GlslWriter::Append(const T& param)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
m_currentState->stream << param;
}
template<typename T>
void GlslWriter::DeclareVariables(const ShaderAst& shader, const std::vector<T>& variables, const std::string& keyword, const std::string& section)
{
if (!variables.empty())
{
if (!section.empty())
AppendCommentSection(section);
for (const auto& var : variables)
{
if constexpr (std::is_same_v<T, ShaderAst::InputOutput>)
{
if (var.locationIndex)
{
Append("layout(location = ");
Append(*var.locationIndex);
Append(") ");
}
if (!keyword.empty())
{
Append(keyword);
Append(" ");
}
Append(var.type);
Append(" ");
Append(var.name);
AppendLine(";");
}
else if constexpr (std::is_same_v<T, ShaderAst::Uniform>)
{
if (var.bindingIndex || var.memoryLayout)
{
Append("layout(");
bool first = true;
if (var.bindingIndex)
{
if (!first)
Append(", ");
Append("binding = ");
Append(*var.bindingIndex);
first = false;
}
if (var.memoryLayout)
{
if (!first)
Append(", ");
Append(*var.memoryLayout);
first = false;
}
Append(") ");
}
if (!keyword.empty())
{
Append(keyword);
Append(" ");
}
std::visit([&](auto&& arg)
{
using U = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<U, ShaderNodes::BasicType>)
{
Append(arg);
Append(" ");
Append(var.name);
}
else if constexpr (std::is_same_v<U, std::string>)
{
const auto& structs = shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
throw std::runtime_error("struct " + arg + " has not been defined");
const auto& s = *it;
AppendLine(var.name + "_interface");
AppendLine("{");
for (const auto& m : s.members)
{
Append("\t");
Append(m.type);
Append(" ");
Append(m.name);
AppendLine(";");
}
Append("} ");
Append(var.name);
}
else
static_assert(AlwaysFalse<U>::value, "non-exhaustive visitor");
}, var.type);
AppendLine(";");
AppendLine();
}
}
AppendLine();
}
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -1,130 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Renderer module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_AST_HPP
#define NAZARA_SHADER_AST_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Utility/Enums.hpp>
#include <optional>
#include <unordered_map>
#include <vector>
namespace Nz
{
class NAZARA_SHADER_API ShaderAst
{
public:
struct Condition;
struct Function;
struct FunctionParameter;
struct InputOutput;
struct Struct;
struct StructMember;
struct Uniform;
struct VariableBase;
inline ShaderAst(ShaderStageType shaderStage);
ShaderAst(const ShaderAst&) = default;
ShaderAst(ShaderAst&&) noexcept = default;
~ShaderAst() = default;
void AddCondition(std::string name);
void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters = {}, ShaderExpressionType returnType = ShaderNodes::BasicType::Void);
void AddInput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex = {});
void AddOutput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex = {});
void AddStruct(std::string name, std::vector<StructMember> members);
void AddUniform(std::string name, ShaderExpressionType type, std::optional<std::size_t> bindingIndex = {}, std::optional<ShaderNodes::MemoryLayout> memoryLayout = {});
inline std::size_t FindConditionByName(const std::string_view& conditionName) const;
inline const Condition& GetCondition(std::size_t i) const;
inline std::size_t GetConditionCount() const;
inline const std::vector<Condition>& GetConditions() const;
inline const Function& GetFunction(std::size_t i) const;
inline std::size_t GetFunctionCount() const;
inline const std::vector<Function>& GetFunctions() const;
inline const InputOutput& GetInput(std::size_t i) const;
inline std::size_t GetInputCount() const;
inline const std::vector<InputOutput>& GetInputs() const;
inline const InputOutput& GetOutput(std::size_t i) const;
inline std::size_t GetOutputCount() const;
inline const std::vector<InputOutput>& GetOutputs() const;
inline ShaderStageType GetStage() const;
inline const Struct& GetStruct(std::size_t i) const;
inline std::size_t GetStructCount() const;
inline const std::vector<Struct>& GetStructs() const;
inline const Uniform& GetUniform(std::size_t i) const;
inline std::size_t GetUniformCount() const;
inline const std::vector<Uniform>& GetUniforms() const;
ShaderAst& operator=(const ShaderAst&) = default;
ShaderAst& operator=(ShaderAst&&) noexcept = default;
struct Condition
{
std::string name;
};
struct VariableBase
{
std::string name;
ShaderExpressionType type;
};
struct FunctionParameter : VariableBase
{
};
struct Function
{
std::string name;
std::vector<FunctionParameter> parameters;
ShaderExpressionType returnType;
ShaderNodes::StatementPtr statement;
};
struct InputOutput : VariableBase
{
std::optional<std::size_t> locationIndex;
};
struct Uniform : VariableBase
{
std::optional<std::size_t> bindingIndex;
std::optional<ShaderNodes::MemoryLayout> memoryLayout;
};
struct Struct
{
std::string name;
std::vector<StructMember> members;
};
struct StructMember
{
std::string name;
ShaderExpressionType type;
};
static constexpr std::size_t InvalidCondition = std::numeric_limits<std::size_t>::max();
private:
std::vector<Condition> m_conditions;
std::vector<Function> m_functions;
std::vector<InputOutput> m_inputs;
std::vector<InputOutput> m_outputs;
std::vector<Struct> m_structs;
std::vector<Uniform> m_uniforms;
ShaderStageType m_stage;
};
}
#include <Nazara/Shader/ShaderAst.inl>
#endif // NAZARA_SHADER_AST_HPP

View File

@ -1,128 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
inline ShaderAst::ShaderAst(ShaderStageType shaderStage) :
m_stage(shaderStage)
{
}
inline std::size_t ShaderAst::FindConditionByName(const std::string_view& conditionName) const
{
for (std::size_t i = 0; i < m_conditions.size(); ++i)
{
if (m_conditions[i].name == conditionName)
return i;
}
return InvalidCondition;
}
inline auto Nz::ShaderAst::GetCondition(std::size_t i) const -> const Condition&
{
assert(i < m_functions.size());
return m_conditions[i];
}
inline std::size_t ShaderAst::GetConditionCount() const
{
return m_conditions.size();
}
inline auto ShaderAst::GetConditions() const -> const std::vector<Condition>&
{
return m_conditions;
}
inline auto ShaderAst::GetFunction(std::size_t i) const -> const Function&
{
assert(i < m_functions.size());
return m_functions[i];
}
inline std::size_t ShaderAst::GetFunctionCount() const
{
return m_functions.size();
}
inline auto ShaderAst::GetFunctions() const -> const std::vector<Function>&
{
return m_functions;
}
inline auto ShaderAst::GetInput(std::size_t i) const -> const InputOutput&
{
assert(i < m_inputs.size());
return m_inputs[i];
}
inline std::size_t ShaderAst::GetInputCount() const
{
return m_inputs.size();
}
inline auto ShaderAst::GetInputs() const -> const std::vector<InputOutput>&
{
return m_inputs;
}
inline auto ShaderAst::GetOutput(std::size_t i) const -> const InputOutput&
{
assert(i < m_outputs.size());
return m_outputs[i];
}
inline std::size_t ShaderAst::GetOutputCount() const
{
return m_outputs.size();
}
inline auto ShaderAst::GetOutputs() const -> const std::vector<InputOutput>&
{
return m_outputs;
}
inline ShaderStageType ShaderAst::GetStage() const
{
return m_stage;
}
inline auto ShaderAst::GetStruct(std::size_t i) const -> const Struct&
{
assert(i < m_structs.size());
return m_structs[i];
}
inline std::size_t ShaderAst::GetStructCount() const
{
return m_structs.size();
}
inline auto ShaderAst::GetStructs() const -> const std::vector<Struct>&
{
return m_structs;
}
inline auto ShaderAst::GetUniform(std::size_t i) const -> const Uniform&
{
assert(i < m_uniforms.size());
return m_uniforms[i];
}
inline std::size_t ShaderAst::GetUniformCount() const
{
return m_uniforms.size();
}
inline auto ShaderAst::GetUniforms() const -> const std::vector<Uniform>&
{
return m_uniforms;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -0,0 +1,48 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTCACHE_HPP
#define NAZARA_SHADERASTCACHE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Utility/Enums.hpp>
namespace Nz::ShaderAst
{
struct AstCache
{
struct Variable
{
ShaderExpressionType type;
};
struct Identifier
{
std::string name;
std::variant<Variable, StructDescription> value;
};
struct Scope
{
std::optional<std::size_t> parentScopeIndex;
std::vector<Identifier> identifiers;
};
inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const;
inline std::size_t GetScopeId(const Node* node) const;
ShaderStageType stageType = ShaderStageType::Undefined;
std::unordered_map<const Expression*, ShaderExpressionType> nodeExpressionType;
std::unordered_map<const Node*, std::size_t> scopeIdByNode;
std::vector<Scope> scopes;
};
}
#include <Nazara/Shader/ShaderAstCache.inl>
#endif

View File

@ -0,0 +1,37 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier*
{
assert(startingScopeId < scopes.size());
std::optional<std::size_t> scopeId = startingScopeId;
do
{
const auto& scope = scopes[*scopeId];
auto it = std::find_if(scope.identifiers.rbegin(), scope.identifiers.rend(), [&](const auto& identifier) { return identifier.name == identifierName; });
if (it != scope.identifiers.rend())
return &*it;
scopeId = scope.parentScopeIndex;
} while (scopeId);
return nullptr;
}
inline std::size_t AstCache::GetScopeId(const Node* node) const
{
auto it = scopeIdByNode.find(node);
assert(it == scopeIdByNode.end());
return it->second;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -9,70 +9,62 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <vector>
namespace Nz
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API ShaderAstCloner : public ShaderAstVisitor, public ShaderVarVisitor
class NAZARA_SHADER_API AstCloner : public AstExpressionVisitor, public AstStatementVisitor
{
public:
ShaderAstCloner() = default;
ShaderAstCloner(const ShaderAstCloner&) = delete;
ShaderAstCloner(ShaderAstCloner&&) = delete;
~ShaderAstCloner() = default;
AstCloner() = default;
AstCloner(const AstCloner&) = delete;
AstCloner(AstCloner&&) = delete;
~AstCloner() = default;
ShaderNodes::StatementPtr Clone(const ShaderNodes::StatementPtr& statement);
ExpressionPtr Clone(ExpressionPtr& statement);
StatementPtr Clone(StatementPtr& statement);
ShaderAstCloner& operator=(const ShaderAstCloner&) = delete;
ShaderAstCloner& operator=(ShaderAstCloner&&) = delete;
AstCloner& operator=(const AstCloner&) = delete;
AstCloner& operator=(AstCloner&&) = delete;
protected:
ShaderNodes::ExpressionPtr CloneExpression(const ShaderNodes::ExpressionPtr& expr);
ShaderNodes::StatementPtr CloneStatement(const ShaderNodes::StatementPtr& statement);
ShaderNodes::VariablePtr CloneVariable(const ShaderNodes::VariablePtr& statement);
ExpressionPtr CloneExpression(ExpressionPtr& expr);
StatementPtr CloneStatement(StatementPtr& statement);
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::Discard& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::NoOp& node) override;
void Visit(ShaderNodes::ReturnStatement& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
using AstExpressionVisitor::Visit;
using AstStatementVisitor::Visit;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::BuiltinVariable& var) override;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::OutputVariable& var) override;
void Visit(ShaderNodes::ParameterVariable& var) override;
void Visit(ShaderNodes::UniformVariable& var) override;
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(DiscardStatement& node) override;
void Visit(ExpressionStatement& node) override;
void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override;
void Visit(ReturnStatement& node) override;
void PushExpression(ShaderNodes::ExpressionPtr expression);
void PushStatement(ShaderNodes::StatementPtr statement);
void PushVariable(ShaderNodes::VariablePtr variable);
void PushExpression(ExpressionPtr expression);
void PushStatement(StatementPtr statement);
ShaderNodes::ExpressionPtr PopExpression();
ShaderNodes::StatementPtr PopStatement();
ShaderNodes::VariablePtr PopVariable();
ExpressionPtr PopExpression();
StatementPtr PopStatement();
private:
std::vector<ShaderNodes::ExpressionPtr> m_expressionStack;
std::vector<ShaderNodes::StatementPtr> m_statementStack;
std::vector<ShaderNodes::VariablePtr> m_variableStack;
std::vector<ExpressionPtr> m_expressionStack;
std::vector<StatementPtr> m_statementStack;
};
}

View File

@ -0,0 +1,57 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTEXPRESSIONTYPE_HPP
#define NAZARA_SHADERASTEXPRESSIONTYPE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <vector>
namespace Nz::ShaderAst
{
struct AstCache;
class NAZARA_SHADER_API ExpressionTypeVisitor : public AstExpressionVisitor
{
public:
ExpressionTypeVisitor() = default;
ExpressionTypeVisitor(const ExpressionTypeVisitor&) = delete;
ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete;
~ExpressionTypeVisitor() = default;
ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache);
ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete;
ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete;
private:
ShaderExpressionType GetExpressionTypeInternal(Expression& expression);
void Visit(Expression& expression);
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
AstCache* m_cache;
std::optional<ShaderExpressionType> m_lastExpressionType;
};
inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr);
}
#include <Nazara/Shader/ShaderAstExpressionType.inl>
#endif

View File

@ -0,0 +1,17 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache)
{
ExpressionTypeVisitor visitor;
return visitor.GetExpressionType(expression, cache);
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -0,0 +1,32 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTEXPRESSIONVISITOR_HPP
#define NAZARA_SHADERASTEXPRESSIONVISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API AstExpressionVisitor
{
public:
AstExpressionVisitor() = default;
AstExpressionVisitor(const AstExpressionVisitor&) = delete;
AstExpressionVisitor(AstExpressionVisitor&&) = delete;
virtual ~AstExpressionVisitor();
#define NAZARA_SHADERAST_EXPRESSION(NodeType) virtual void Visit(NodeType& node) = 0;
#include <Nazara/Shader/ShaderAstNodes.hpp>
AstExpressionVisitor& operator=(const AstExpressionVisitor&) = delete;
AstExpressionVisitor& operator=(AstExpressionVisitor&&) = delete;
};
}
#endif

View File

@ -0,0 +1,26 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTEXPRESSIONVISITOREXCEPT_HPP
#define NAZARA_SHADERASTEXPRESSIONVISITOREXCEPT_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API ExpressionVisitorExcept : public AstExpressionVisitor
{
public:
using AstExpressionVisitor::Visit;
#define NAZARA_SHADERAST_EXPRESSION(Node) void Visit(ShaderAst::Node& node) override;
#include <Nazara/Shader/ShaderAstNodes.hpp>
};
}
#endif

View File

@ -0,0 +1,53 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Renderer module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#if !defined(NAZARA_SHADERAST_NODE) && !defined(NAZARA_SHADERAST_EXPRESSION) && !defined(NAZARA_SHADERAST_STATEMENT)
#error You must define NAZARA_SHADERAST_NODE or NAZARA_SHADERAST_EXPRESSION or NAZARA_SHADERAST_STATEMENT before including this file
#endif
#ifndef NAZARA_SHADERAST_NODE
#define NAZARA_SHADERAST_NODE(X)
#endif
#ifndef NAZARA_SHADERAST_NODE_LAST
#define NAZARA_SHADERAST_NODE_LAST(X)
#endif
#ifndef NAZARA_SHADERAST_EXPRESSION
#define NAZARA_SHADERAST_EXPRESSION(X) NAZARA_SHADERAST_NODE(X)
#endif
#ifndef NAZARA_SHADERAST_STATEMENT
#define NAZARA_SHADERAST_STATEMENT(X) NAZARA_SHADERAST_NODE(X)
#endif
#ifndef NAZARA_SHADERAST_STATEMENT_LAST
#define NAZARA_SHADERAST_STATEMENT_LAST(X) NAZARA_SHADERAST_STATEMENT(X)
#endif
NAZARA_SHADERAST_EXPRESSION(AccessMemberExpression)
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
NAZARA_SHADERAST_EXPRESSION(CastExpression)
NAZARA_SHADERAST_EXPRESSION(ConditionalExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantExpression)
NAZARA_SHADERAST_EXPRESSION(IdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement)
NAZARA_SHADERAST_STATEMENT(ConditionalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement)
NAZARA_SHADERAST_STATEMENT(DeclareStructStatement)
NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement)
NAZARA_SHADERAST_STATEMENT(DiscardStatement)
NAZARA_SHADERAST_STATEMENT(ExpressionStatement)
NAZARA_SHADERAST_STATEMENT(MultiStatement)
NAZARA_SHADERAST_STATEMENT(NoOpStatement)
NAZARA_SHADERAST_STATEMENT_LAST(ReturnStatement)
#undef NAZARA_SHADERAST_EXPRESSION
#undef NAZARA_SHADERAST_NODE
#undef NAZARA_SHADERAST_NODE_LAST
#undef NAZARA_SHADERAST_STATEMENT
#undef NAZARA_SHADERAST_STATEMENT_LAST

View File

@ -12,35 +12,32 @@
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <vector>
namespace Nz
namespace Nz::ShaderAst
{
class ShaderAst;
class NAZARA_SHADER_API ShaderAstOptimizer : public ShaderAstCloner
class NAZARA_SHADER_API AstOptimizer : public AstCloner
{
public:
ShaderAstOptimizer() = default;
ShaderAstOptimizer(const ShaderAstOptimizer&) = delete;
ShaderAstOptimizer(ShaderAstOptimizer&&) = delete;
~ShaderAstOptimizer() = default;
AstOptimizer() = default;
AstOptimizer(const AstOptimizer&) = delete;
AstOptimizer(AstOptimizer&&) = delete;
~AstOptimizer() = default;
ShaderNodes::StatementPtr Optimise(const ShaderNodes::StatementPtr& statement);
ShaderNodes::StatementPtr Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions);
StatementPtr Optimise(StatementPtr& statement);
StatementPtr Optimise(StatementPtr& statement, UInt64 enabledConditions);
ShaderAstOptimizer& operator=(const ShaderAstOptimizer&) = delete;
ShaderAstOptimizer& operator=(ShaderAstOptimizer&&) = delete;
AstOptimizer& operator=(const AstOptimizer&) = delete;
AstOptimizer& operator=(AstOptimizer&&) = delete;
protected:
using ShaderAstCloner::Visit;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
using AstCloner::Visit;
void Visit(BinaryExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
template<ShaderNodes::BinaryType Type> void PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs);
template<BinaryType Type> void PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs);
private:
const ShaderAst* m_shaderAst;
UInt64 m_enabledConditions;
};
}

View File

@ -9,36 +9,37 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
namespace Nz
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API ShaderAstRecursiveVisitor : public ShaderAstVisitor
class NAZARA_SHADER_API AstRecursiveVisitor : public AstExpressionVisitor, public AstStatementVisitor
{
public:
ShaderAstRecursiveVisitor() = default;
~ShaderAstRecursiveVisitor() = default;
AstRecursiveVisitor() = default;
~AstRecursiveVisitor() = default;
using ShaderAstVisitor::Visit;
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::Discard& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::NoOp& node) override;
void Visit(ShaderNodes::ReturnStatement& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(DiscardStatement& node) override;
void Visit(ExpressionStatement& node) override;
void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override;
void Visit(ReturnStatement& node) override;
};
}

View File

@ -11,40 +11,38 @@
#include <Nazara/Core/ByteArray.hpp>
#include <Nazara/Core/ByteStream.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Shader/ShaderVariables.hpp>
namespace Nz
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API ShaderAstSerializerBase
class NAZARA_SHADER_API AstSerializerBase
{
public:
ShaderAstSerializerBase() = default;
ShaderAstSerializerBase(const ShaderAstSerializerBase&) = delete;
ShaderAstSerializerBase(ShaderAstSerializerBase&&) = delete;
~ShaderAstSerializerBase() = default;
AstSerializerBase() = default;
AstSerializerBase(const AstSerializerBase&) = delete;
AstSerializerBase(AstSerializerBase&&) = delete;
~AstSerializerBase() = default;
void Serialize(ShaderNodes::AccessMember& node);
void Serialize(ShaderNodes::AssignOp& node);
void Serialize(ShaderNodes::BinaryOp& node);
void Serialize(ShaderNodes::BuiltinVariable& var);
void Serialize(ShaderNodes::Branch& node);
void Serialize(ShaderNodes::Cast& node);
void Serialize(ShaderNodes::ConditionalExpression& node);
void Serialize(ShaderNodes::ConditionalStatement& node);
void Serialize(ShaderNodes::Constant& node);
void Serialize(ShaderNodes::DeclareVariable& node);
void Serialize(ShaderNodes::Discard& node);
void Serialize(ShaderNodes::ExpressionStatement& node);
void Serialize(ShaderNodes::Identifier& node);
void Serialize(ShaderNodes::IntrinsicCall& node);
void Serialize(ShaderNodes::NamedVariable& var);
void Serialize(ShaderNodes::NoOp& node);
void Serialize(ShaderNodes::ReturnStatement& node);
void Serialize(ShaderNodes::Sample2D& node);
void Serialize(ShaderNodes::StatementBlock& node);
void Serialize(ShaderNodes::SwizzleOp& node);
void Serialize(AccessMemberExpression& node);
void Serialize(AssignExpression& node);
void Serialize(BinaryExpression& node);
void Serialize(CastExpression& node);
void Serialize(ConditionalExpression& node);
void Serialize(ConstantExpression& node);
void Serialize(IdentifierExpression& node);
void Serialize(IntrinsicExpression& node);
void Serialize(SwizzleExpression& node);
void Serialize(BranchStatement& node);
void Serialize(ConditionalStatement& node);
void Serialize(DeclareFunctionStatement& node);
void Serialize(DeclareStructStatement& node);
void Serialize(DeclareVariableStatement& node);
void Serialize(DiscardStatement& node);
void Serialize(ExpressionStatement& node);
void Serialize(MultiStatement& node);
void Serialize(NoOpStatement& node);
void Serialize(ReturnStatement& node);
protected:
template<typename T> void Container(T& container);
@ -54,8 +52,8 @@ namespace Nz
virtual bool IsWriting() const = 0;
virtual void Node(ShaderNodes::NodePtr& node) = 0;
template<typename T> void Node(std::shared_ptr<T>& node);
virtual void Node(ExpressionPtr& node) = 0;
virtual void Node(StatementPtr& node) = 0;
virtual void Type(ShaderExpressionType& type) = 0;
@ -74,23 +72,20 @@ namespace Nz
virtual void Value(UInt32& val) = 0;
virtual void Value(UInt64& val) = 0;
inline void SizeT(std::size_t& val);
virtual void Variable(ShaderNodes::VariablePtr& var) = 0;
template<typename T> void Variable(std::shared_ptr<T>& var);
};
class NAZARA_SHADER_API ShaderAstSerializer final : public ShaderAstSerializerBase
class NAZARA_SHADER_API ShaderAstSerializer final : public AstSerializerBase
{
public:
inline ShaderAstSerializer(ByteStream& stream);
~ShaderAstSerializer() = default;
void Serialize(const ShaderAst& shader);
void Serialize(StatementPtr& shader);
private:
bool IsWriting() const override;
void Node(const ShaderNodes::NodePtr& node);
void Node(ShaderNodes::NodePtr& node) override;
void Node(ExpressionPtr& node) override;
void Node(StatementPtr& node) override;
void Type(ShaderExpressionType& type) override;
void Value(bool& val) override;
void Value(float& val) override;
@ -106,22 +101,22 @@ namespace Nz
void Value(UInt16& val) override;
void Value(UInt32& val) override;
void Value(UInt64& val) override;
void Variable(ShaderNodes::VariablePtr& var) override;
ByteStream& m_stream;
};
class NAZARA_SHADER_API ShaderAstUnserializer final : public ShaderAstSerializerBase
class NAZARA_SHADER_API ShaderAstUnserializer final : public AstSerializerBase
{
public:
ShaderAstUnserializer(ByteStream& stream);
~ShaderAstUnserializer() = default;
ShaderAst Unserialize();
StatementPtr Unserialize();
private:
bool IsWriting() const override;
void Node(ShaderNodes::NodePtr& node) override;
void Node(ExpressionPtr& node) override;
void Node(StatementPtr& node) override;
void Type(ShaderExpressionType& type) override;
void Value(bool& val) override;
void Value(float& val) override;
@ -137,14 +132,13 @@ namespace Nz
void Value(UInt16& val) override;
void Value(UInt32& val) override;
void Value(UInt64& val) override;
void Variable(ShaderNodes::VariablePtr& var) override;
ByteStream& m_stream;
};
NAZARA_SHADER_API ByteArray SerializeShader(const ShaderAst& shader);
inline ShaderAst UnserializeShader(const void* data, std::size_t size);
NAZARA_SHADER_API ShaderAst UnserializeShader(ByteStream& stream);
NAZARA_SHADER_API ByteArray SerializeShader(StatementPtr& shader);
inline StatementPtr UnserializeShader(const void* data, std::size_t size);
NAZARA_SHADER_API StatementPtr UnserializeShader(ByteStream& stream);
}
#include <Nazara/Shader/ShaderAstSerializer.inl>

View File

@ -5,10 +5,10 @@
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
template<typename T>
void ShaderAstSerializerBase::Container(T& container)
void AstSerializerBase::Container(T& container)
{
bool isWriting = IsWriting();
@ -23,7 +23,7 @@ namespace Nz
template<typename T>
void ShaderAstSerializerBase::Enum(T& enumVal)
void AstSerializerBase::Enum(T& enumVal)
{
bool isWriting = IsWriting();
@ -37,7 +37,7 @@ namespace Nz
}
template<typename T>
void ShaderAstSerializerBase::OptEnum(std::optional<T>& optVal)
void AstSerializerBase::OptEnum(std::optional<T>& optVal)
{
bool isWriting = IsWriting();
@ -55,7 +55,7 @@ namespace Nz
}
template<typename T>
void ShaderAstSerializerBase::OptVal(std::optional<T>& optVal)
void AstSerializerBase::OptVal(std::optional<T>& optVal)
{
bool isWriting = IsWriting();
@ -77,21 +77,7 @@ namespace Nz
}
}
template<typename T>
void ShaderAstSerializerBase::Node(std::shared_ptr<T>& node)
{
bool isWriting = IsWriting();
ShaderNodes::NodePtr value;
if (isWriting)
value = node;
Node(value);
if (!isWriting)
node = std::static_pointer_cast<T>(value);
}
inline void ShaderAstSerializerBase::SizeT(std::size_t& val)
inline void AstSerializerBase::SizeT(std::size_t& val)
{
bool isWriting = IsWriting();
@ -105,20 +91,6 @@ namespace Nz
val = static_cast<std::size_t>(fixedVal);
}
template<typename T>
void ShaderAstSerializerBase::Variable(std::shared_ptr<T>& var)
{
bool isWriting = IsWriting();
ShaderNodes::VariablePtr value;
if (isWriting)
value = var;
Variable(value);
if (!isWriting)
var = std::static_pointer_cast<T>(value);
}
inline ShaderAstSerializer::ShaderAstSerializer(ByteStream& stream) :
m_stream(stream)
{
@ -129,7 +101,7 @@ namespace Nz
{
}
inline ShaderAst UnserializeShader(const void* data, std::size_t size)
inline StatementPtr UnserializeShader(const void* data, std::size_t size)
{
ByteStream byteStream(data, size);
return UnserializeShader(byteStream);

View File

@ -0,0 +1,32 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTSTATEMENTVISITOR_HPP
#define NAZARA_SHADERASTSTATEMENTVISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API AstStatementVisitor
{
public:
AstStatementVisitor() = default;
AstStatementVisitor(const AstStatementVisitor&) = delete;
AstStatementVisitor(AstStatementVisitor&&) = delete;
virtual ~AstStatementVisitor();
#define NAZARA_SHADERAST_STATEMENT(NodeType) virtual void Visit(ShaderAst::NodeType& node) = 0;
#include <Nazara/Shader/ShaderAstNodes.hpp>
AstStatementVisitor& operator=(const AstStatementVisitor&) = delete;
AstStatementVisitor& operator=(AstStatementVisitor&&) = delete;
};
}
#endif

View File

@ -0,0 +1,26 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTSTATEMENTVISITOREXCEPT_HPP
#define NAZARA_SHADERASTSTATEMENTVISITOREXCEPT_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API StatementVisitorExcept : public AstStatementVisitor
{
public:
using AstStatementVisitor::Visit;
#define NAZARA_SHADERAST_STATEMENT(Node) void Visit(ShaderAst::Node& node) override;
#include <Nazara/Shader/ShaderAstNodes.hpp>
};
}
#endif

View File

@ -4,17 +4,30 @@
#pragma once
#ifndef NAZARA_SHADER_EXPRESSIONTYPE_HPP
#define NAZARA_SHADER_EXPRESSIONTYPE_HPP
#ifndef NAZARA_SHADER_ASTTYPES_HPP
#define NAZARA_SHADER_ASTTYPES_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <string>
#include <variant>
#include <vector>
namespace Nz
namespace Nz::ShaderAst
{
using ShaderExpressionType = std::variant<ShaderNodes::BasicType, std::string>;
using ShaderExpressionType = std::variant<BasicType, std::string>;
struct StructDescription
{
struct StructMember
{
std::string name;
ShaderExpressionType type;
};
std::string name;
std::vector<StructMember> members;
};
inline bool IsBasicType(const ShaderExpressionType& type);
inline bool IsMatrixType(const ShaderExpressionType& type);
@ -22,6 +35,6 @@ namespace Nz
inline bool IsStructType(const ShaderExpressionType& type);
}
#include <Nazara/Shader/ShaderExpressionType.inl>
#include <Nazara/Shader/ShaderAstTypes.inl>
#endif // NAZARA_SHADER_EXPRESSIONTYPE_HPP
#endif // NAZARA_SHADER_ASTTYPES_HPP

View File

@ -2,18 +2,18 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
inline bool IsBasicType(const ShaderExpressionType& type)
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
if constexpr (std::is_same_v<T, BasicType>)
return true;
else if constexpr (std::is_same_v<T, std::string>)
return false;
@ -25,8 +25,6 @@ namespace Nz
inline bool IsMatrixType(const ShaderExpressionType& type)
{
using namespace ShaderNodes;
if (!IsBasicType(type))
return false;
@ -58,8 +56,6 @@ namespace Nz
inline bool IsSamplerType(const ShaderExpressionType& type)
{
using namespace ShaderNodes;
if (!IsBasicType(type))
return false;
@ -94,7 +90,7 @@ namespace Nz
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
if constexpr (std::is_same_v<T, BasicType>)
return false;
else if constexpr (std::is_same_v<T, std::string>)
return true;

View File

@ -10,14 +10,12 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <vector>
namespace Nz
namespace Nz::ShaderAst
{
class ShaderAst;
class NAZARA_SHADER_API ShaderAstValueCategory final : public ShaderAstVisitorExcept
class NAZARA_SHADER_API ShaderAstValueCategory final : public AstExpressionVisitor
{
public:
ShaderAstValueCategory() = default;
@ -25,28 +23,28 @@ namespace Nz
ShaderAstValueCategory(ShaderAstValueCategory&&) = delete;
~ShaderAstValueCategory() = default;
ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression);
ExpressionCategory GetExpressionCategory(Expression& expression);
ShaderAstValueCategory& operator=(const ShaderAstValueCategory&) = delete;
ShaderAstValueCategory& operator=(ShaderAstValueCategory&&) = delete;
private:
using ShaderAstVisitorExcept::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
using AstExpressionVisitor::Visit;
ShaderNodes::ExpressionCategory m_expressionCategory;
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
ExpressionCategory m_expressionCategory;
};
inline ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression);
inline ExpressionCategory GetExpressionCategory(Expression& expression);
}
#include <Nazara/Shader/ShaderAstUtils.inl>

View File

@ -5,9 +5,9 @@
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression)
ExpressionCategory GetExpressionCategory(Expression& expression)
{
ShaderAstValueCategory visitor;
return visitor.GetExpressionCategory(expression);

View File

@ -8,66 +8,62 @@
#define NAZARA_SHADERVALIDATOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Core/ByteArray.hpp>
#include <Nazara/Core/ByteStream.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Utility/Enums.hpp>
namespace Nz
namespace Nz::ShaderAst
{
class NAZARA_SHADER_API ShaderAstValidator : public ShaderAstRecursiveVisitor, public ShaderVarVisitor
class NAZARA_SHADER_API AstValidator : public AstRecursiveVisitor
{
public:
inline ShaderAstValidator(const ShaderAst& shader);
ShaderAstValidator(const ShaderAstValidator&) = delete;
ShaderAstValidator(ShaderAstValidator&&) = delete;
~ShaderAstValidator() = default;
inline AstValidator();
AstValidator(const AstValidator&) = delete;
AstValidator(AstValidator&&) = delete;
~AstValidator() = default;
bool Validate(std::string* error = nullptr);
bool Validate(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr);
private:
const ShaderNodes::ExpressionPtr& MandatoryExpr(const ShaderNodes::ExpressionPtr& node);
const ShaderNodes::NodePtr& MandatoryNode(const ShaderNodes::NodePtr& node);
void TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right);
Expression& MandatoryExpr(ExpressionPtr& node);
Statement& MandatoryStatement(StatementPtr& node);
void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right);
void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right);
const ShaderAst::StructMember& CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers);
ShaderExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
using ShaderAstRecursiveVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::ReturnStatement& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
AstCache::Scope& EnterScope();
void ExitScope();
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::BuiltinVariable& var) override;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::OutputVariable& var) override;
void Visit(ShaderNodes::ParameterVariable& var) override;
void Visit(ShaderNodes::UniformVariable& var) override;
void RegisterExpressionType(Expression& node, ShaderExpressionType expressionType);
void RegisterScope(Node& node);
void Visit(AccessMemberExpression& node) override;
void Visit(AssignExpression& node) override;
void Visit(BinaryExpression& node) override;
void Visit(CastExpression& node) override;
void Visit(ConditionalExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(ExpressionStatement& node) override;
void Visit(MultiStatement& node) override;
void Visit(ReturnStatement& node) override;
struct Context;
const ShaderAst& m_shader;
Context* m_context;
};
NAZARA_SHADER_API bool ValidateShader(const ShaderAst& shader, std::string* error = nullptr);
NAZARA_SHADER_API bool ValidateAst(StatementPtr& node, std::string* error = nullptr, AstCache* cache = nullptr);
}
#include <Nazara/Shader/ShaderAstValidator.inl>

View File

@ -5,10 +5,10 @@
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderAstValidator::ShaderAstValidator(const ShaderAst& shader) :
m_shader(shader)
AstValidator::AstValidator() :
m_context(nullptr)
{
}
}

View File

@ -1,49 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTVISITOR_HPP
#define NAZARA_SHADERASTVISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
namespace Nz
{
class NAZARA_SHADER_API ShaderAstVisitor
{
public:
ShaderAstVisitor() = default;
ShaderAstVisitor(const ShaderAstVisitor&) = delete;
ShaderAstVisitor(ShaderAstVisitor&&) = delete;
virtual ~ShaderAstVisitor();
void Visit(const ShaderNodes::NodePtr& node);
virtual void Visit(ShaderNodes::AccessMember& node) = 0;
virtual void Visit(ShaderNodes::AssignOp& node) = 0;
virtual void Visit(ShaderNodes::BinaryOp& node) = 0;
virtual void Visit(ShaderNodes::Branch& node) = 0;
virtual void Visit(ShaderNodes::Cast& node) = 0;
virtual void Visit(ShaderNodes::ConditionalExpression& node) = 0;
virtual void Visit(ShaderNodes::ConditionalStatement& node) = 0;
virtual void Visit(ShaderNodes::Constant& node) = 0;
virtual void Visit(ShaderNodes::DeclareVariable& node) = 0;
virtual void Visit(ShaderNodes::Discard& node) = 0;
virtual void Visit(ShaderNodes::ExpressionStatement& node) = 0;
virtual void Visit(ShaderNodes::Identifier& node) = 0;
virtual void Visit(ShaderNodes::IntrinsicCall& node) = 0;
virtual void Visit(ShaderNodes::NoOp& node) = 0;
virtual void Visit(ShaderNodes::ReturnStatement& node) = 0;
virtual void Visit(ShaderNodes::Sample2D& node) = 0;
virtual void Visit(ShaderNodes::StatementBlock& node) = 0;
virtual void Visit(ShaderNodes::SwizzleOp& node) = 0;
ShaderAstVisitor& operator=(const ShaderAstVisitor&) = delete;
ShaderAstVisitor& operator=(ShaderAstVisitor&&) = delete;
};
}
#endif

View File

@ -1,41 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERASTVISITOREXCEPT_HPP
#define NAZARA_SHADERASTVISITOREXCEPT_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
namespace Nz
{
class NAZARA_SHADER_API ShaderAstVisitorExcept : public ShaderAstVisitor
{
public:
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::Discard& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::NoOp& node) override;
void Visit(ShaderNodes::ReturnStatement& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
};
}
#endif

View File

@ -13,70 +13,60 @@
namespace Nz::ShaderBuilder
{
template<ShaderNodes::AssignType op>
struct AssignOpBuilder
namespace Impl
{
constexpr AssignOpBuilder() = default;
struct Binary
{
inline std::unique_ptr<ShaderAst::BinaryExpression> operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const;
};
std::shared_ptr<ShaderNodes::AssignOp> operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const;
};
struct Branch
{
inline std::unique_ptr<ShaderAst::BranchStatement> operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath = nullptr) const;
inline std::unique_ptr<ShaderAst::BranchStatement> operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement = nullptr) const;
};
template<ShaderNodes::BinaryType op>
struct BinOpBuilder
{
constexpr BinOpBuilder() = default;
struct Constant
{
inline std::unique_ptr<ShaderAst::ConstantExpression> operator()(ShaderConstantValue value) const;
};
std::shared_ptr<ShaderNodes::BinaryOp> operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const;
};
struct DeclareFunction
{
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const;
};
struct BuiltinBuilder
{
constexpr BuiltinBuilder() = default;
struct DeclareVariable
{
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const;
};
inline std::shared_ptr<ShaderNodes::Variable> operator()(ShaderNodes::BuiltinEntry builtin) const;
};
struct Identifier
{
inline std::unique_ptr<ShaderAst::IdentifierExpression> operator()(std::string name) const;
};
template<typename T>
struct GenBuilder
{
constexpr GenBuilder() = default;
struct Return
{
inline std::unique_ptr<ShaderAst::ReturnStatement> operator()(ShaderAst::ExpressionPtr expr = nullptr) const;
};
template<typename... Args> std::shared_ptr<T> operator()(Args&&... args) const;
};
template<typename T>
struct NoParam
{
std::unique_ptr<T> operator()() const;
};
}
constexpr GenBuilder<ShaderNodes::AccessMember> AccessMember;
constexpr BinOpBuilder<ShaderNodes::BinaryType::Add> Add;
constexpr AssignOpBuilder<ShaderNodes::AssignType::Simple> Assign;
constexpr BuiltinBuilder Builtin;
constexpr GenBuilder<ShaderNodes::StatementBlock> Block;
constexpr GenBuilder<ShaderNodes::Branch> Branch;
constexpr GenBuilder<ShaderNodes::ConditionalExpression> ConditionalExpression;
constexpr GenBuilder<ShaderNodes::ConditionalStatement> ConditionalStatement;
constexpr GenBuilder<ShaderNodes::Constant> Constant;
constexpr GenBuilder<ShaderNodes::DeclareVariable> DeclareVariable;
constexpr GenBuilder<ShaderNodes::Discard> Discard;
constexpr BinOpBuilder<ShaderNodes::BinaryType::Divide> Division;
constexpr BinOpBuilder<ShaderNodes::BinaryType::CompEq> Equal;
constexpr BinOpBuilder<ShaderNodes::BinaryType::CompGt> GreaterThan;
constexpr BinOpBuilder<ShaderNodes::BinaryType::CompGe> GreaterThanOrEqual;
constexpr BinOpBuilder<ShaderNodes::BinaryType::CompLt> LessThan;
constexpr BinOpBuilder<ShaderNodes::BinaryType::CompLe> LessThanOrEqual;
constexpr BinOpBuilder<ShaderNodes::BinaryType::CompNe> NotEqual;
constexpr GenBuilder<ShaderNodes::ExpressionStatement> ExprStatement;
constexpr GenBuilder<ShaderNodes::Identifier> Identifier;
constexpr GenBuilder<ShaderNodes::IntrinsicCall> IntrinsicCall;
constexpr GenBuilder<ShaderNodes::InputVariable> Input;
constexpr GenBuilder<ShaderNodes::LocalVariable> Local;
constexpr BinOpBuilder<ShaderNodes::BinaryType::Multiply> Multiply;
constexpr GenBuilder<ShaderNodes::OutputVariable> Output;
constexpr GenBuilder<ShaderNodes::ParameterVariable> Parameter;
constexpr GenBuilder<ShaderNodes::Sample2D> Sample2D;
constexpr GenBuilder<ShaderNodes::StatementBlock> StatementBlock;
constexpr GenBuilder<ShaderNodes::SwizzleOp> Swizzle;
constexpr BinOpBuilder<ShaderNodes::BinaryType::Subtract> Subtract;
constexpr GenBuilder<ShaderNodes::UniformVariable> Uniform;
template<ShaderNodes::BasicType Type, typename... Args> std::shared_ptr<ShaderNodes::Cast> Cast(Args&&... args);
constexpr Impl::Binary Binary;
constexpr Impl::Branch Branch;
constexpr Impl::Constant Constant;
constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareVariable DeclareVariable;
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::Identifier Identifier;
constexpr Impl::NoParam<ShaderAst::NoOpStatement> NoOp;
constexpr Impl::Return Return;
}
#include <Nazara/Shader/ShaderBuilder.inl>

View File

@ -7,45 +7,87 @@
namespace Nz::ShaderBuilder
{
inline std::unique_ptr<ShaderAst::BinaryExpression> Impl::Binary::operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const
{
auto constantNode = std::make_unique<ShaderAst::BinaryExpression>();
constantNode->op = op;
constantNode->left = std::move(left);
constantNode->right = std::move(right);
return constantNode;
}
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const
{
auto branchNode = std::make_unique<ShaderAst::BranchStatement>();
auto& condStatement = branchNode->condStatements.emplace_back();
condStatement.condition = std::move(condition);
condStatement.statement = std::move(truePath);
branchNode->elseStatement = std::move(falsePath);
return branchNode;
}
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement) const
{
auto branchNode = std::make_unique<ShaderAst::BranchStatement>();
branchNode->condStatements = std::move(condStatements);
branchNode->elseStatement = std::move(elseStatement);
return branchNode;
}
inline std::unique_ptr<ShaderAst::ConstantExpression> Impl::Constant::operator()(ShaderConstantValue value) const
{
auto constantNode = std::make_unique<ShaderAst::ConstantExpression>();
constantNode->value = std::move(value);
return constantNode;
}
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ShaderExpressionType returnType) const
{
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();
declareFunctionNode->name = std::move(name);
declareFunctionNode->parameters = std::move(parameters);
declareFunctionNode->returnType = std::move(returnType);
declareFunctionNode->statements = std::move(statements);
return declareFunctionNode;
}
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue) const
{
auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>();
declareVariableNode->varName = std::move(name);
declareVariableNode->varType = std::move(type);
declareVariableNode->initialExpression = std::move(initialValue);
return declareVariableNode;
}
inline std::unique_ptr<ShaderAst::IdentifierExpression> Impl::Identifier::operator()(std::string name) const
{
auto identifierNode = std::make_unique<ShaderAst::IdentifierExpression>();
identifierNode->identifier = std::move(name);
return identifierNode;
}
inline std::unique_ptr<ShaderAst::ReturnStatement> Impl::Return::operator()(ShaderAst::ExpressionPtr expr) const
{
auto returnNode = std::make_unique<ShaderAst::ReturnStatement>();
returnNode->returnExpr = std::move(expr);
return returnNode;
}
template<typename T>
template<typename... Args>
std::shared_ptr<T> GenBuilder<T>::operator()(Args&&... args) const
std::unique_ptr<T> Impl::NoParam<T>::operator()() const
{
return T::Build(std::forward<Args>(args)...);
}
template<ShaderNodes::AssignType op>
std::shared_ptr<ShaderNodes::AssignOp> AssignOpBuilder<op>::operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const
{
return ShaderNodes::AssignOp::Build(op, left, right);
}
template<ShaderNodes::BinaryType op>
std::shared_ptr<ShaderNodes::BinaryOp> BinOpBuilder<op>::operator()(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right) const
{
return ShaderNodes::BinaryOp::Build(op, left, right);
}
inline std::shared_ptr<ShaderNodes::Variable> BuiltinBuilder::operator()(ShaderNodes::BuiltinEntry builtin) const
{
ShaderNodes::BasicType exprType = ShaderNodes::BasicType::Void;
switch (builtin)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
exprType = ShaderNodes::BasicType::Float4;
break;
}
NazaraAssert(exprType != ShaderNodes::BasicType::Void, "Unhandled builtin");
return ShaderNodes::BuiltinVariable::Build(builtin, exprType);
}
template<ShaderNodes::BasicType Type, typename... Args>
std::shared_ptr<ShaderNodes::Cast> Cast(Args&&... args)
{
return ShaderNodes::Cast::Build(Type, std::forward<Args>(args)...);
return std::make_unique<T>();
}
}

View File

@ -9,7 +9,7 @@
#include <Nazara/Prerequisites.hpp>
namespace Nz::ShaderNodes
namespace Nz::ShaderAst
{
enum class AssignType
{
@ -77,35 +77,9 @@ namespace Nz::ShaderNodes
{
None = -1,
AccessMember,
AssignOp,
BinaryOp,
Branch,
Cast,
Constant,
ConditionalExpression,
ConditionalStatement,
DeclareVariable,
Discard,
ExpressionStatement,
Identifier,
IntrinsicCall,
NoOp,
ReturnStatement,
Sample2D,
SwizzleOp,
StatementBlock,
Max = StatementBlock
};
enum class SsaInstruction
{
OpAdd,
OpDiv,
OpMul,
OpSub,
OpSample
#define NAZARA_SHADERAST_NODE(Node) Node,
#define NAZARA_SHADERAST_STATEMENT_LAST(Node) Node, Max = Node
#include <Nazara/Shader/ShaderAstNodes.hpp>
};
enum class SwizzleComponent
@ -127,6 +101,11 @@ namespace Nz::ShaderNodes
ParameterVariable,
UniformVariable
};
inline std::size_t GetComponentCount(BasicType type);
inline BasicType GetComponentType(BasicType type);
}
#include <Nazara/Shader/ShaderEnums.inl>
#endif // NAZARA_SHADER_ENUMS_HPP

View File

@ -0,0 +1,57 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline std::size_t GetComponentCount(BasicType type)
{
switch (type)
{
case BasicType::Float2:
case BasicType::Int2:
return 2;
case BasicType::Float3:
case BasicType::Int3:
return 3;
case BasicType::Float4:
case BasicType::Int4:
return 4;
case BasicType::Mat4x4:
return 4;
default:
return 1;
}
}
inline BasicType GetComponentType(BasicType type)
{
switch (type)
{
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
return BasicType::Float1;
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
return BasicType::Int1;
case BasicType::Mat4x4:
return BasicType::Float4;
default:
return type;
}
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -10,7 +10,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderLangLexer.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
namespace Nz::ShaderLang
{
@ -44,7 +44,7 @@ namespace Nz::ShaderLang
inline Parser();
~Parser() = default;
ShaderAst Parse(const std::vector<Token>& tokens);
ShaderAst::StatementPtr Parse(const std::vector<Token>& tokens);
private:
// Flow control
@ -54,29 +54,30 @@ namespace Nz::ShaderLang
const Token& PeekNext();
// Statements
ShaderNodes::StatementPtr ParseFunctionBody();
void ParseFunctionDeclaration();
ShaderAst::FunctionParameter ParseFunctionParameter();
ShaderNodes::StatementPtr ParseReturnStatement();
ShaderNodes::StatementPtr ParseStatement();
ShaderNodes::StatementPtr ParseStatementList();
std::vector<ShaderAst::StatementPtr> ParseFunctionBody();
ShaderAst::StatementPtr ParseFunctionDeclaration();
ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter();
ShaderAst::StatementPtr ParseReturnStatement();
ShaderAst::StatementPtr ParseStatement();
std::vector<ShaderAst::StatementPtr> ParseStatementList();
ShaderAst::StatementPtr ParseVariableDeclaration();
// Expressions
ShaderNodes::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs);
ShaderNodes::ExpressionPtr ParseExpression();
ShaderNodes::ExpressionPtr ParseIdentifier();
ShaderNodes::ExpressionPtr ParseIntegerExpression();
ShaderNodes::ExpressionPtr ParseParenthesisExpression();
ShaderNodes::ExpressionPtr ParsePrimaryExpression();
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);
ShaderAst::ExpressionPtr ParseExpression();
ShaderAst::ExpressionPtr ParseIdentifier();
ShaderAst::ExpressionPtr ParseIntegerExpression();
ShaderAst::ExpressionPtr ParseParenthesisExpression();
ShaderAst::ExpressionPtr ParsePrimaryExpression();
std::string ParseIdentifierAsName();
ShaderExpressionType ParseIdentifierAsType();
ShaderAst::ShaderExpressionType ParseIdentifierAsType();
static int GetTokenPrecedence(TokenType token);
struct Context
{
ShaderAst result;
std::unique_ptr<ShaderAst::MultiStatement> root;
std::size_t tokenCount;
std::size_t tokenIndex = 0;
const Token* tokens;

View File

@ -6,10 +6,11 @@
#error You must define NAZARA_SHADERLANG_TOKEN before including this file
#endif
#ifndef NAZARA_SHADERLANG_TOKENT_LAST
#ifndef NAZARA_SHADERLANG_TOKEN_LAST
#define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X)
#endif
NAZARA_SHADERLANG_TOKEN(Assign)
NAZARA_SHADERLANG_TOKEN(BoolFalse)
NAZARA_SHADERLANG_TOKEN(BoolTrue)
NAZARA_SHADERLANG_TOKEN(ClosingParenthesis)
@ -24,6 +25,7 @@ NAZARA_SHADERLANG_TOKEN(FunctionDeclaration)
NAZARA_SHADERLANG_TOKEN(FunctionReturn)
NAZARA_SHADERLANG_TOKEN(IntegerValue)
NAZARA_SHADERLANG_TOKEN(Identifier)
NAZARA_SHADERLANG_TOKEN(Let)
NAZARA_SHADERLANG_TOKEN(Multiply)
NAZARA_SHADERLANG_TOKEN(Minus)
NAZARA_SHADERLANG_TOKEN(Plus)

View File

@ -14,308 +14,245 @@
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <array>
#include <memory>
#include <optional>
#include <string>
namespace Nz
namespace Nz::ShaderAst
{
class ShaderAstVisitor;
class AstExpressionVisitor;
class AstStatementVisitor;
namespace ShaderNodes
struct NAZARA_SHADER_API Node
{
class Node;
Node() = default;
Node(const Node&) = delete;
Node(Node&&) noexcept = default;
virtual ~Node();
using NodePtr = std::shared_ptr<Node>;
virtual NodeType GetType() const = 0;
class NAZARA_SHADER_API Node
Node& operator=(const Node&) = delete;
Node& operator=(Node&&) noexcept = default;
};
// Expressions
struct Expression;
using ExpressionPtr = std::unique_ptr<Expression>;
struct NAZARA_SHADER_API Expression : Node
{
Expression() = default;
Expression(const Expression&) = delete;
Expression(Expression&&) noexcept = default;
~Expression() = default;
virtual void Visit(AstExpressionVisitor& visitor) = 0;
Expression& operator=(const Expression&) = delete;
Expression& operator=(Expression&&) noexcept = default;
};
struct NAZARA_SHADER_API AccessMemberExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ExpressionPtr structExpr;
std::vector<std::string> memberIdentifiers;
};
struct NAZARA_SHADER_API AssignExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
AssignType op;
ExpressionPtr left;
ExpressionPtr right;
};
struct NAZARA_SHADER_API BinaryExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
BinaryType op;
ExpressionPtr left;
ExpressionPtr right;
};
struct NAZARA_SHADER_API CastExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
BasicType targetType;
std::array<ExpressionPtr, 4> expressions;
};
struct NAZARA_SHADER_API ConditionalExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::string conditionName;
ExpressionPtr falsePath;
ExpressionPtr truePath;
};
struct NAZARA_SHADER_API ConstantExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
ShaderConstantValue value;
};
struct NAZARA_SHADER_API IdentifierExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::string identifier;
};
struct NAZARA_SHADER_API IntrinsicExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
IntrinsicType intrinsic;
std::vector<ExpressionPtr> parameters;
};
struct NAZARA_SHADER_API SwizzleExpression : public Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::array<SwizzleComponent, 4> components;
std::size_t componentCount;
ExpressionPtr expression;
};
// Statements
struct Statement;
using StatementPtr = std::unique_ptr<Statement>;
struct NAZARA_SHADER_API Statement : Node
{
Statement() = default;
Statement(const Statement&) = delete;
Statement(Statement&&) noexcept = default;
~Statement() = default;
virtual void Visit(AstStatementVisitor& visitor) = 0;
Statement& operator=(const Statement&) = delete;
Statement& operator=(Statement&&) noexcept = default;
};
struct NAZARA_SHADER_API BranchStatement : public Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
struct ConditionalStatement
{
public:
virtual ~Node();
inline NodeType GetType() const;
inline bool IsStatement() const;
virtual void Visit(ShaderAstVisitor& visitor) = 0;
static inline unsigned int GetComponentCount(BasicType type);
static inline BasicType GetComponentType(BasicType type);
protected:
inline Node(NodeType type, bool isStatement);
private:
NodeType m_type;
bool m_isStatement;
};
class Expression;
using ExpressionPtr = std::shared_ptr<Expression>;
class NAZARA_SHADER_API Expression : public Node, public std::enable_shared_from_this<Expression>
{
public:
inline Expression(NodeType type);
virtual ShaderExpressionType GetExpressionType() const = 0;
ExpressionPtr condition;
StatementPtr statement;
};
class Statement;
std::vector<ConditionalStatement> condStatements;
StatementPtr elseStatement;
};
using StatementPtr = std::shared_ptr<Statement>;
struct NAZARA_SHADER_API ConditionalStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
class NAZARA_SHADER_API Statement : public Node, public std::enable_shared_from_this<Statement>
std::string conditionName;
StatementPtr statement;
};
struct NAZARA_SHADER_API DeclareFunctionStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
struct Parameter
{
public:
inline Statement(NodeType type);
std::string name;
ShaderExpressionType type;
};
struct NAZARA_SHADER_API ExpressionStatement : public Statement
{
inline ExpressionStatement();
void Visit(ShaderAstVisitor& visitor) override;
ExpressionPtr expression;
static inline std::shared_ptr<ExpressionStatement> Build(ExpressionPtr expr);
};
//////////////////////////////////////////////////////////////////////////
struct NAZARA_SHADER_API ConditionalStatement : public Statement
{
inline ConditionalStatement();
void Visit(ShaderAstVisitor& visitor) override;
std::string conditionName;
StatementPtr statement;
static inline std::shared_ptr<ConditionalStatement> Build(std::string condition, StatementPtr statementPtr);
};
struct NAZARA_SHADER_API StatementBlock : public Statement
{
inline StatementBlock();
void Visit(ShaderAstVisitor& visitor) override;
std::vector<StatementPtr> statements;
static inline std::shared_ptr<StatementBlock> Build(std::vector<StatementPtr> statements);
template<typename... Args> static std::shared_ptr<StatementBlock> Build(Args&&... args);
};
struct NAZARA_SHADER_API DeclareVariable : public Statement
{
inline DeclareVariable();
void Visit(ShaderAstVisitor& visitor) override;
ExpressionPtr expression;
VariablePtr variable;
static inline std::shared_ptr<DeclareVariable> Build(VariablePtr variable, ExpressionPtr expression = nullptr);
};
struct NAZARA_SHADER_API Discard : public Statement
{
inline Discard();
void Visit(ShaderAstVisitor& visitor) override;
static inline std::shared_ptr<Discard> Build();
};
struct NAZARA_SHADER_API Identifier : public Expression
{
inline Identifier();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
VariablePtr var;
static inline std::shared_ptr<Identifier> Build(VariablePtr variable);
};
struct NAZARA_SHADER_API AccessMember : public Expression
{
inline AccessMember();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
ExpressionPtr structExpr;
ShaderExpressionType exprType;
std::vector<std::size_t> memberIndices;
static inline std::shared_ptr<AccessMember> Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType);
static inline std::shared_ptr<AccessMember> Build(ExpressionPtr structExpr, std::vector<std::size_t> memberIndices, ShaderExpressionType exprType);
};
struct NAZARA_SHADER_API NoOp : public Statement
{
inline NoOp();
void Visit(ShaderAstVisitor& visitor) override;
static inline std::shared_ptr<NoOp> Build();
};
struct NAZARA_SHADER_API ReturnStatement : public Statement
{
inline ReturnStatement();
void Visit(ShaderAstVisitor& visitor) override;
ExpressionPtr returnExpr;
static inline std::shared_ptr<ReturnStatement> Build(ExpressionPtr expr = nullptr);
};
//////////////////////////////////////////////////////////////////////////
struct NAZARA_SHADER_API AssignOp : public Expression
{
inline AssignOp();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
AssignType op;
ExpressionPtr left;
ExpressionPtr right;
static inline std::shared_ptr<AssignOp> Build(AssignType op, ExpressionPtr left, ExpressionPtr right);
};
struct NAZARA_SHADER_API BinaryOp : public Expression
{
inline BinaryOp();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
BinaryType op;
ExpressionPtr left;
ExpressionPtr right;
static inline std::shared_ptr<BinaryOp> Build(BinaryType op, ExpressionPtr left, ExpressionPtr right);
};
struct NAZARA_SHADER_API Branch : public Statement
{
struct ConditionalStatement;
inline Branch();
void Visit(ShaderAstVisitor& visitor) override;
std::vector<ConditionalStatement> condStatements;
StatementPtr elseStatement;
struct ConditionalStatement
{
ExpressionPtr condition;
StatementPtr statement;
};
static inline std::shared_ptr<Branch> Build(ExpressionPtr condition, StatementPtr trueStatement, StatementPtr falseStatement = nullptr);
static inline std::shared_ptr<Branch> Build(std::vector<ConditionalStatement> statements, StatementPtr elseStatement = nullptr);
};
struct NAZARA_SHADER_API Cast : public Expression
{
inline Cast();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
BasicType exprType;
std::array<ExpressionPtr, 4> expressions;
static inline std::shared_ptr<Cast> Build(BasicType castTo, ExpressionPtr first, ExpressionPtr second = nullptr, ExpressionPtr third = nullptr, ExpressionPtr fourth = nullptr);
static inline std::shared_ptr<Cast> Build(BasicType castTo, ExpressionPtr* expressions, std::size_t expressionCount);
};
struct NAZARA_SHADER_API ConditionalExpression : public Expression
{
inline ConditionalExpression();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
std::string conditionName;
ExpressionPtr falsePath;
ExpressionPtr truePath;
static inline std::shared_ptr<ConditionalExpression> Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath);
};
struct NAZARA_SHADER_API Constant : public Expression
{
inline Constant();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
ShaderConstantValue value;
template<typename T> static std::shared_ptr<Constant> Build(const T& value);
};
struct NAZARA_SHADER_API SwizzleOp : public Expression
{
inline SwizzleOp();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
std::array<SwizzleComponent, 4> components;
std::size_t componentCount;
ExpressionPtr expression;
static inline std::shared_ptr<SwizzleOp> Build(ExpressionPtr expressionPtr, SwizzleComponent swizzleComponent);
static inline std::shared_ptr<SwizzleOp> Build(ExpressionPtr expressionPtr, std::initializer_list<SwizzleComponent> swizzleComponents);
static inline std::shared_ptr<SwizzleOp> Build(ExpressionPtr expressionPtr, const SwizzleComponent* components, std::size_t componentCount);
};
//////////////////////////////////////////////////////////////////////////
struct NAZARA_SHADER_API Sample2D : public Expression
{
inline Sample2D();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
ExpressionPtr sampler;
ExpressionPtr coordinates;
static inline std::shared_ptr<Sample2D> Build(ExpressionPtr samplerPtr, ExpressionPtr coordinatesPtr);
};
//////////////////////////////////////////////////////////////////////////
struct NAZARA_SHADER_API IntrinsicCall : public Expression
{
inline IntrinsicCall();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
IntrinsicType intrinsic;
std::vector<ExpressionPtr> parameters;
static inline std::shared_ptr<IntrinsicCall> Build(IntrinsicType intrinsic, std::vector<ExpressionPtr> parameters);
};
}
std::string name;
std::vector<Parameter> parameters;
std::vector<StatementPtr> statements;
ShaderExpressionType returnType = BasicType::Void;
};
struct NAZARA_SHADER_API DeclareStructStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
StructDescription description;
};
struct NAZARA_SHADER_API DeclareVariableStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::string varName;
ExpressionPtr initialExpression;
ShaderExpressionType varType;
};
struct NAZARA_SHADER_API DiscardStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
};
struct NAZARA_SHADER_API ExpressionStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
ExpressionPtr expression;
};
struct NAZARA_SHADER_API MultiStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::vector<StatementPtr> statements;
};
struct NAZARA_SHADER_API NoOpStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
};
struct NAZARA_SHADER_API ReturnStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
ExpressionPtr returnExpr;
};
}
#include <Nazara/Shader/ShaderNodes.inl>

View File

@ -5,394 +5,8 @@
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderNodes
namespace Nz::ShaderAst
{
inline Node::Node(NodeType type, bool isStatement) :
m_type(type),
m_isStatement(isStatement)
{
}
inline NodeType ShaderNodes::Node::GetType() const
{
return m_type;
}
inline bool Node::IsStatement() const
{
return m_isStatement;
}
inline unsigned int Node::GetComponentCount(BasicType type)
{
switch (type)
{
case BasicType::Float2:
case BasicType::Int2:
return 2;
case BasicType::Float3:
case BasicType::Int3:
return 3;
case BasicType::Float4:
case BasicType::Int4:
return 4;
case BasicType::Mat4x4:
return 4;
default:
return 1;
}
}
inline BasicType Node::GetComponentType(BasicType type)
{
switch (type)
{
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
return BasicType::Float1;
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
return BasicType::Int1;
case BasicType::Mat4x4:
return BasicType::Float4;
default:
return type;
}
}
inline Expression::Expression(NodeType type) :
Node(type, false)
{
}
inline Statement::Statement(NodeType type) :
Node(type, true)
{
}
inline ExpressionStatement::ExpressionStatement() :
Statement(NodeType::ExpressionStatement)
{
}
inline std::shared_ptr<ExpressionStatement> ExpressionStatement::Build(ExpressionPtr expr)
{
auto node = std::make_shared<ExpressionStatement>();
node->expression = std::move(expr);
return node;
}
inline ConditionalStatement::ConditionalStatement() :
Statement(NodeType::ConditionalStatement)
{
}
inline std::shared_ptr<ConditionalStatement> ConditionalStatement::Build(std::string condition, StatementPtr statementPtr)
{
auto node = std::make_shared<ConditionalStatement>();
node->conditionName = std::move(condition);
node->statement = std::move(statementPtr);
return node;
}
inline StatementBlock::StatementBlock() :
Statement(NodeType::StatementBlock)
{
}
inline std::shared_ptr<StatementBlock> StatementBlock::Build(std::vector<StatementPtr> statements)
{
auto node = std::make_shared<StatementBlock>();
node->statements = std::move(statements);
return node;
}
template<typename... Args>
std::shared_ptr<StatementBlock> StatementBlock::Build(Args&&... args)
{
auto node = std::make_shared<StatementBlock>();
node->statements = std::vector<StatementPtr>({ std::forward<Args>(args)... });
return node;
}
inline DeclareVariable::DeclareVariable() :
Statement(NodeType::DeclareVariable)
{
}
inline std::shared_ptr<DeclareVariable> DeclareVariable::Build(VariablePtr variable, ExpressionPtr expression)
{
auto node = std::make_shared<DeclareVariable>();
node->expression = std::move(expression);
node->variable = std::move(variable);
return node;
}
inline Discard::Discard() :
Statement(NodeType::Discard)
{
}
inline std::shared_ptr<Discard> Discard::Build()
{
return std::make_shared<Discard>();
}
inline Identifier::Identifier() :
Expression(NodeType::Identifier)
{
}
inline std::shared_ptr<Identifier> Identifier::Build(VariablePtr variable)
{
auto node = std::make_shared<Identifier>();
node->var = std::move(variable);
return node;
}
inline AccessMember::AccessMember() :
Expression(NodeType::AccessMember)
{
}
inline std::shared_ptr<AccessMember> AccessMember::Build(ExpressionPtr structExpr, std::size_t memberIndex, ShaderExpressionType exprType)
{
return Build(std::move(structExpr), std::vector<std::size_t>{ memberIndex }, exprType);
}
inline std::shared_ptr<AccessMember> AccessMember::Build(ExpressionPtr structExpr, std::vector<std::size_t> memberIndices, ShaderExpressionType exprType)
{
auto node = std::make_shared<AccessMember>();
node->exprType = std::move(exprType);
node->memberIndices = std::move(memberIndices);
node->structExpr = std::move(structExpr);
return node;
}
inline NoOp::NoOp() :
Statement(NodeType::NoOp)
{
}
inline std::shared_ptr<NoOp> NoOp::Build()
{
return std::make_shared<NoOp>();
}
inline ReturnStatement::ReturnStatement() :
Statement(NodeType::ReturnStatement)
{
}
inline std::shared_ptr<ReturnStatement> ShaderNodes::ReturnStatement::Build(ExpressionPtr expr)
{
auto node = std::make_shared<ReturnStatement>();
node->returnExpr = std::move(expr);
return node;
}
inline AssignOp::AssignOp() :
Expression(NodeType::AssignOp)
{
}
inline std::shared_ptr<AssignOp> AssignOp::Build(AssignType op, ExpressionPtr left, ExpressionPtr right)
{
auto node = std::make_shared<AssignOp>();
node->op = op;
node->left = std::move(left);
node->right = std::move(right);
return node;
}
inline BinaryOp::BinaryOp() :
Expression(NodeType::BinaryOp)
{
}
inline std::shared_ptr<BinaryOp> BinaryOp::Build(BinaryType op, ExpressionPtr left, ExpressionPtr right)
{
auto node = std::make_shared<BinaryOp>();
node->op = op;
node->left = std::move(left);
node->right = std::move(right);
return node;
}
inline Branch::Branch() :
Statement(NodeType::Branch)
{
}
inline std::shared_ptr<Branch> Branch::Build(ExpressionPtr condition, StatementPtr trueStatement, StatementPtr falseStatement)
{
auto node = std::make_shared<Branch>();
node->condStatements.emplace_back(ConditionalStatement{ std::move(condition), std::move(trueStatement) });
node->elseStatement = std::move(falseStatement);
return node;
}
inline std::shared_ptr<Branch> Branch::Build(std::vector<ConditionalStatement> statements, StatementPtr elseStatement)
{
auto node = std::make_shared<Branch>();
node->condStatements = std::move(statements);
node->elseStatement = std::move(elseStatement);
return node;
}
inline Cast::Cast() :
Expression(NodeType::Cast)
{
}
inline std::shared_ptr<Cast> Cast::Build(BasicType castTo, ExpressionPtr first, ExpressionPtr second, ExpressionPtr third, ExpressionPtr fourth)
{
auto node = std::make_shared<Cast>();
node->exprType = castTo;
node->expressions = { {first, second, third, fourth} };
return node;
}
inline std::shared_ptr<Cast> Cast::Build(BasicType castTo, ExpressionPtr* Expressions, std::size_t expressionCount)
{
auto node = std::make_shared<Cast>();
node->exprType = castTo;
for (std::size_t i = 0; i < expressionCount; ++i)
node->expressions[i] = Expressions[i];
return node;
}
inline ConditionalExpression::ConditionalExpression() :
Expression(NodeType::ConditionalExpression)
{
}
inline std::shared_ptr<ConditionalExpression> ShaderNodes::ConditionalExpression::Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath)
{
auto node = std::make_shared<ConditionalExpression>();
node->conditionName = std::move(condition);
node->falsePath = std::move(falsePath);
node->truePath = std::move(truePath);
return node;
}
inline Constant::Constant() :
Expression(NodeType::Constant)
{
}
template<typename T>
std::shared_ptr<Constant> Nz::ShaderNodes::Constant::Build(const T& value)
{
auto node = std::make_shared<Constant>();
node->value = value;
return node;
}
inline SwizzleOp::SwizzleOp() :
Expression(NodeType::SwizzleOp)
{
}
inline std::shared_ptr<SwizzleOp> SwizzleOp::Build(ExpressionPtr expressionPtr, SwizzleComponent swizzleComponent)
{
return Build(std::move(expressionPtr), { swizzleComponent });
}
inline std::shared_ptr<SwizzleOp> SwizzleOp::Build(ExpressionPtr expressionPtr, std::initializer_list<SwizzleComponent> swizzleComponents)
{
auto node = std::make_shared<SwizzleOp>();
node->componentCount = swizzleComponents.size();
node->expression = std::move(expressionPtr);
std::copy(swizzleComponents.begin(), swizzleComponents.end(), node->components.begin());
return node;
}
inline std::shared_ptr<SwizzleOp> SwizzleOp::Build(ExpressionPtr expressionPtr, const SwizzleComponent* components, std::size_t componentCount)
{
auto node = std::make_shared<SwizzleOp>();
assert(componentCount < node->components.size());
node->componentCount = componentCount;
node->expression = std::move(expressionPtr);
std::copy(components, components + componentCount, node->components.begin());
return node;
}
inline Sample2D::Sample2D() :
Expression(NodeType::Sample2D)
{
}
inline std::shared_ptr<Sample2D> Sample2D::Build(ExpressionPtr samplerPtr, ExpressionPtr coordinatesPtr)
{
auto node = std::make_shared<Sample2D>();
node->coordinates = std::move(coordinatesPtr);
node->sampler = std::move(samplerPtr);
return node;
}
inline IntrinsicCall::IntrinsicCall() :
Expression(NodeType::IntrinsicCall)
{
}
inline std::shared_ptr<IntrinsicCall> IntrinsicCall::Build(IntrinsicType intrinsic, std::vector<ExpressionPtr> parameters)
{
auto node = std::make_shared<IntrinsicCall>();
node->intrinsic = intrinsic;
node->parameters = std::move(parameters);
return node;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -1,38 +0,0 @@
// Copyright (C) 2015 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERVARVISITOR_HPP
#define NAZARA_SHADERVARVISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderVariables.hpp>
namespace Nz
{
class NAZARA_SHADER_API ShaderVarVisitor
{
public:
ShaderVarVisitor() = default;
ShaderVarVisitor(const ShaderVarVisitor&) = delete;
ShaderVarVisitor(ShaderVarVisitor&&) = delete;
virtual ~ShaderVarVisitor();
void Visit(const ShaderNodes::VariablePtr& node);
virtual void Visit(ShaderNodes::BuiltinVariable& var) = 0;
virtual void Visit(ShaderNodes::InputVariable& var) = 0;
virtual void Visit(ShaderNodes::LocalVariable& var) = 0;
virtual void Visit(ShaderNodes::OutputVariable& var) = 0;
virtual void Visit(ShaderNodes::ParameterVariable& var) = 0;
virtual void Visit(ShaderNodes::UniformVariable& var) = 0;
ShaderVarVisitor& operator=(const ShaderVarVisitor&) = delete;
ShaderVarVisitor& operator=(ShaderVarVisitor&&) = delete;
};
}
#endif

View File

@ -1,28 +0,0 @@
// Copyright (C) 2015 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERVARVISITOREXCEPT_HPP
#define NAZARA_SHADERVARVISITOREXCEPT_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
namespace Nz
{
class NAZARA_SHADER_API ShaderVarVisitorExcept : public ShaderVarVisitor
{
public:
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::BuiltinVariable& var) override;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::OutputVariable& var) override;
void Visit(ShaderNodes::ParameterVariable& var) override;
void Visit(ShaderNodes::UniformVariable& var) override;
};
}
#endif

View File

@ -1,128 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_VARIABLES_HPP
#define NAZARA_SHADER_VARIABLES_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Math/Vector2.hpp>
#include <Nazara/Math/Vector3.hpp>
#include <Nazara/Math/Vector4.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <array>
#include <memory>
#include <optional>
#include <string>
namespace Nz
{
class ShaderVarVisitor;
namespace ShaderNodes
{
struct Variable;
using VariablePtr = std::shared_ptr<Variable>;
struct NAZARA_SHADER_API Variable : std::enable_shared_from_this<Variable>
{
virtual ~Variable();
virtual VariableType GetType() const = 0;
virtual void Visit(ShaderVarVisitor& visitor) = 0;
ShaderExpressionType type;
};
struct BuiltinVariable;
using BuiltinVariablePtr = std::shared_ptr<BuiltinVariable>;
struct NAZARA_SHADER_API BuiltinVariable : public Variable
{
BuiltinEntry entry;
VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<BuiltinVariable> Build(BuiltinEntry entry, ShaderExpressionType varType);
};
struct NamedVariable;
using NamedVariablePtr = std::shared_ptr<NamedVariable>;
struct NAZARA_SHADER_API NamedVariable : public Variable
{
std::string name;
};
struct InputVariable;
using InputVariablePtr = std::shared_ptr<InputVariable>;
struct NAZARA_SHADER_API InputVariable : public NamedVariable
{
VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<InputVariable> Build(std::string varName, ShaderExpressionType varType);
};
struct LocalVariable;
using LocalVariablePtr = std::shared_ptr<LocalVariable>;
struct NAZARA_SHADER_API LocalVariable : public NamedVariable
{
VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<LocalVariable> Build(std::string varName, ShaderExpressionType varType);
};
struct OutputVariable;
using OutputVariablePtr = std::shared_ptr<OutputVariable>;
struct NAZARA_SHADER_API OutputVariable : public NamedVariable
{
VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<OutputVariable> Build(std::string varName, ShaderExpressionType varType);
};
struct ParameterVariable;
using ParameterVariablePtr = std::shared_ptr<ParameterVariable>;
struct NAZARA_SHADER_API ParameterVariable : public NamedVariable
{
VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<ParameterVariable> Build(std::string varName, ShaderExpressionType varType);
};
struct UniformVariable;
using UniformVariablePtr = std::shared_ptr<UniformVariable>;
struct NAZARA_SHADER_API UniformVariable : public NamedVariable
{
VariableType GetType() const override;
void Visit(ShaderVarVisitor& visitor) override;
static inline std::shared_ptr<UniformVariable> Build(std::string varName, ShaderExpressionType varType);
};
}
}
#include <Nazara/Shader/ShaderVariables.inl>
#endif

View File

@ -1,65 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderNodes
{
inline std::shared_ptr<BuiltinVariable> BuiltinVariable::Build(BuiltinEntry variable, ShaderExpressionType varType)
{
auto node = std::make_shared<BuiltinVariable>();
node->entry = variable;
node->type = varType;
return node;
}
inline std::shared_ptr<InputVariable> InputVariable::Build(std::string varName, ShaderExpressionType varType)
{
auto node = std::make_shared<InputVariable>();
node->name = std::move(varName);
node->type = varType;
return node;
}
inline std::shared_ptr<LocalVariable> LocalVariable::Build(std::string varName, ShaderExpressionType varType)
{
auto node = std::make_shared<LocalVariable>();
node->name = std::move(varName);
node->type = varType;
return node;
}
inline std::shared_ptr<OutputVariable> OutputVariable::Build(std::string varName, ShaderExpressionType varType)
{
auto node = std::make_shared<OutputVariable>();
node->name = std::move(varName);
node->type = varType;
return node;
}
inline std::shared_ptr<ParameterVariable> ParameterVariable::Build(std::string varName, ShaderExpressionType varType)
{
auto node = std::make_shared<ParameterVariable>();
node->name = std::move(varName);
node->type = varType;
return node;
}
inline std::shared_ptr<UniformVariable> UniformVariable::Build(std::string varName, ShaderExpressionType varType)
{
auto node = std::make_shared<UniformVariable>();
node->name = std::move(varName);
node->type = varType;
return node;
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -14,8 +14,6 @@
namespace Nz
{
class ShaderAst;
class NAZARA_SHADER_API ShaderWriter
{
public:

View File

@ -9,8 +9,8 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <vector>
@ -18,7 +18,7 @@ namespace Nz
{
class SpirvWriter;
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAstVisitorExcept
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept
{
public:
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks);
@ -26,27 +26,28 @@ namespace Nz
SpirvAstVisitor(SpirvAstVisitor&&) = delete;
~SpirvAstVisitor() = default;
UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr);
UInt32 EvaluateExpression(ShaderAst::ExpressionPtr& expr);
using ShaderAstVisitorExcept::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::Discard& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::NoOp& node) override;
void Visit(ShaderNodes::ReturnStatement& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
using ExpressionVisitorExcept::Visit;
using StatementVisitorExcept::Visit;
void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConditionalExpression& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::DeclareVariableStatement& node) override;
void Visit(ShaderAst::DiscardStatement& node) override;
void Visit(ShaderAst::ExpressionStatement& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override;
void Visit(ShaderAst::ReturnStatement& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;

View File

@ -10,7 +10,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <memory>
#include <optional>
@ -20,7 +20,6 @@
namespace Nz
{
class ShaderAst;
class SpirvSection;
class NAZARA_SHADER_API SpirvConstantCache
@ -173,10 +172,10 @@ namespace Nz
SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept;
static ConstantPtr BuildConstant(const ShaderConstantValue& value);
static TypePtr BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass);
static TypePtr BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass);
static TypePtr BuildType(const ShaderNodes::BasicType& type);
static TypePtr BuildType(const ShaderAst& shader, const ShaderExpressionType& type);
static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass);
static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass);
static TypePtr BuildType(const ShaderAst::BasicType& type);
static TypePtr BuildType(const ShaderAst::ShaderExpressionType& type);
private:
struct DepRegisterer;

View File

@ -9,8 +9,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <vector>
@ -19,7 +18,7 @@ namespace Nz
class SpirvBlock;
class SpirvWriter;
class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept
class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAst::ExpressionVisitorExcept
{
public:
inline SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block);
@ -27,17 +26,11 @@ namespace Nz
SpirvExpressionLoad(SpirvExpressionLoad&&) = delete;
~SpirvExpressionLoad() = default;
UInt32 Evaluate(ShaderNodes::Expression& node);
UInt32 Evaluate(ShaderAst::Expression& node);
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::Identifier& node) override;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::ParameterVariable& var) override;
void Visit(ShaderNodes::UniformVariable& var) override;
using ExpressionVisitorExcept::Visit;
//void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete;
SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete;

View File

@ -9,8 +9,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/SpirvData.hpp>
namespace Nz
@ -18,7 +17,7 @@ namespace Nz
class SpirvBlock;
class SpirvWriter;
class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept
class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAst::ExpressionVisitorExcept
{
public:
inline SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block);
@ -26,17 +25,12 @@ namespace Nz
SpirvExpressionStore(SpirvExpressionStore&&) = delete;
~SpirvExpressionStore() = default;
void Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId);
void Store(ShaderAst::ExpressionPtr& node, UInt32 resultId);
using ShaderAstVisitorExcept::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
using ShaderVarVisitorExcept::Visit;
void Visit(ShaderNodes::BuiltinVariable& var) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::OutputVariable& var) override;
using ExpressionVisitorExcept::Visit;
//void Visit(ShaderAst::AccessMemberExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
SpirvExpressionStore& operator=(const SpirvExpressionStore&) = delete;
SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete;

View File

@ -9,10 +9,8 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <string>
@ -39,7 +37,7 @@ namespace Nz
SpirvWriter(SpirvWriter&&) = delete;
~SpirvWriter() = default;
std::vector<UInt32> Generate(const ShaderAst& shader, const States& conditions = {});
std::vector<UInt32> Generate(ShaderAst::StatementPtr& shader, const States& conditions = {});
void SetEnv(Environment environment);
@ -51,22 +49,23 @@ namespace Nz
private:
struct ExtVar;
struct FunctionParameter;
struct OnlyCache {};
UInt32 AllocateResultId();
void AppendHeader();
SpirvConstantCache::Function BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters);
SpirvConstantCache::Function BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters);
UInt32 GetConstantId(const ShaderConstantValue& value) const;
UInt32 GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters);
const ExtVar& GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const;
UInt32 GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters);
const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const;
const ExtVar& GetInputVariable(const std::string& name) const;
const ExtVar& GetOutputVariable(const std::string& name) const;
const ExtVar& GetUniformVariable(const std::string& name) const;
UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const;
UInt32 GetTypeId(const ShaderExpressionType& type) const;
UInt32 GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const;
UInt32 GetTypeId(const ShaderAst::ShaderExpressionType& type) const;
inline bool IsConditionEnabled(const std::string& condition) const;
@ -82,9 +81,9 @@ namespace Nz
std::optional<UInt32> ReadVariable(const ExtVar& var, OnlyCache);
UInt32 RegisterConstant(const ShaderConstantValue& value);
UInt32 RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters);
UInt32 RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass);
UInt32 RegisterType(ShaderExpressionType type);
UInt32 RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters);
UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass);
UInt32 RegisterType(ShaderAst::ShaderExpressionType type);
void WriteLocalVariable(std::string name, UInt32 resultId);
@ -92,7 +91,7 @@ namespace Nz
struct Context
{
const ShaderAst* shader = nullptr;
ShaderAst::AstCache cache;
const States* states = nullptr;
std::vector<SpirvBlock> functionBlocks;
};
@ -105,6 +104,12 @@ namespace Nz
std::optional<UInt32> valueId;
};
struct FunctionParameter
{
std::string name;
ShaderAst::ShaderExpressionType type;
};
struct State;
Context m_context;

View File

@ -10,10 +10,11 @@ namespace Nz
{
inline bool SpirvWriter::IsConditionEnabled(const std::string& condition) const
{
std::size_t conditionIndex = m_context.shader->FindConditionByName(condition);
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(condition);
assert(conditionIndex != ShaderAst::InvalidCondition);
return TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex);
return TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex);*/
return false;
}
}

View File

@ -167,8 +167,8 @@ namespace Nz
auto& fragmentShader = settings.shaders[UnderlyingCast(ShaderStageType::Fragment)];
auto& vertexShader = settings.shaders[UnderlyingCast(ShaderStageType::Vertex)];
fragmentShader = std::make_shared<UberShader>(UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader)));
vertexShader = std::make_shared<UberShader>(UnserializeShader(r_vertexShader, sizeof(r_vertexShader)));
fragmentShader = std::make_shared<UberShader>(ShaderAst::UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader)));
vertexShader = std::make_shared<UberShader>(ShaderAst::UnserializeShader(r_vertexShader, sizeof(r_vertexShader)));
// Conditions

View File

@ -5,17 +5,17 @@
#include <Nazara/Graphics/UberShader.hpp>
#include <Nazara/Graphics/Graphics.hpp>
#include <Nazara/Renderer/RenderDevice.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <limits>
#include <stdexcept>
#include <Nazara/Graphics/Debug.hpp>
namespace Nz
{
UberShader::UberShader(ShaderAst shaderAst) :
UberShader::UberShader(ShaderAst::StatementPtr shaderAst) :
m_shaderAst(std::move(shaderAst))
{
std::size_t conditionCount = m_shaderAst.GetConditionCount();
//std::size_t conditionCount = m_shaderAst.GetConditionCount();
std::size_t conditionCount = 0;
if (conditionCount >= 64)
throw std::runtime_error("Too many conditions");
@ -27,10 +27,10 @@ namespace Nz
UInt64 UberShader::GetConditionFlagByName(const std::string_view& condition) const
{
std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition);
/*std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition);
if (conditionIndex != ShaderAst::InvalidCondition)
return SetBit<UInt64>(0, conditionIndex);
else
else*/
return 0;
}

View File

@ -8,6 +8,7 @@
#include <Nazara/Math/Algorithm.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <optional>
@ -20,64 +21,67 @@ namespace Nz
{
static const char* flipYUniformName = "_NzFlipValue";
struct AstAdapter : ShaderAstCloner
struct AstAdapter : ShaderAst::AstCloner
{
void Visit(ShaderNodes::AssignOp& node) override
void Visit(ShaderAst::AssignExpression& node) override
{
if (!flipYPosition)
return AstCloner::Visit(node);
if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression)
return AstCloner::Visit(node);
/*
FIXME:
const auto& identifier = static_cast<const ShaderAst::Identifier&>(*node.left);
if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable)
return ShaderAstCloner::Visit(node);
if (node.left->GetType() != ShaderNodes::NodeType::Identifier)
const auto& builtinVar = static_cast<const ShaderAst::BuiltinVariable&>(*identifier.var);
if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition)
return ShaderAstCloner::Visit(node);
const auto& identifier = static_cast<const ShaderNodes::Identifier&>(*node.left);
if (identifier.var->GetType() != ShaderNodes::VariableType::BuiltinVariable)
return ShaderAstCloner::Visit(node);
const auto& builtinVar = static_cast<const ShaderNodes::BuiltinVariable&>(*identifier.var);
if (builtinVar.entry != ShaderNodes::BuiltinEntry::VertexPosition)
return ShaderAstCloner::Visit(node);
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderNodes::BasicType::Float1);
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1);
auto oneConstant = ShaderBuilder::Constant(1.f);
auto fixYValue = ShaderBuilder::Cast<ShaderNodes::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
auto fixYValue = ShaderBuilder::Cast<ShaderAst::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue);
PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));
PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/
}
bool flipYPosition = false;
};
}
struct GlslWriter::State
{
const States* states = nullptr;
ShaderAst::AstCache cache;
std::stringstream stream;
unsigned int indentLevel = 0;
};
GlslWriter::GlslWriter() :
m_currentState(nullptr)
{
}
std::string GlslWriter::Generate(const ShaderAst& inputShader, const States& conditions)
std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
const ShaderAst* selectedShader = &inputShader;
/*const ShaderAst* selectedShader = &inputShader;
std::optional<ShaderAst> modifiedShader;
if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition)
{
modifiedShader.emplace(inputShader);
modifiedShader->AddUniform(flipYUniformName, ShaderNodes::BasicType::Float1);
modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1);
selectedShader = &modifiedShader.value();
}
const ShaderAst& shader = *selectedShader;
std::string error;
if (!ValidateShader(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.states = &conditions;
m_context.shader = &shader;
}*/
State state;
m_currentState = &state;
CallOnExit onExit([this]()
@ -85,6 +89,10 @@ namespace Nz
m_currentState = nullptr;
});
std::string error;
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
throw std::runtime_error("Invalid shader AST: " + error);
unsigned int glslVersion;
if (m_environment.glES)
{
@ -165,52 +173,7 @@ namespace Nz
AppendLine();
}
// Structures
/*if (shader.GetStructCount() > 0)
{
AppendCommentSection("Structures");
for (const auto& s : shader.GetStructs())
{
Append("struct ");
AppendLine(s.name);
AppendLine("{");
for (const auto& m : s.members)
{
Append("\t");
Append(m.type);
Append(" ");
Append(m.name);
AppendLine(";");
}
AppendLine("};");
AppendLine();
}
}*/
// Global variables (uniforms, input and outputs)
const char* inKeyword = (glslVersion >= 130) ? "in" : "varying";
const char* outKeyword = (glslVersion >= 130) ? "out" : "varying";
DeclareVariables(shader, shader.GetUniforms(), "uniform", "Uniforms");
DeclareVariables(shader, shader.GetInputs(), inKeyword, "Inputs");
DeclareVariables(shader, shader.GetOutputs(), outKeyword, "Outputs");
std::size_t functionCount = shader.GetFunctionCount();
if (functionCount > 1)
{
AppendCommentSection("Prototypes");
for (const auto& func : shader.GetFunctions())
{
if (func.name != "main")
{
AppendFunctionPrototype(func);
AppendLine(";");
}
}
}
for (const auto& func : shader.GetFunctions())
AppendFunction(func);
shader->Visit(*this);
return state.stream.str();
}
@ -225,7 +188,7 @@ namespace Nz
return flipYUniformName;
}
void GlslWriter::Append(ShaderExpressionType type)
void GlslWriter::Append(ShaderAst::ShaderExpressionType type)
{
std::visit([&](auto&& arg)
{
@ -233,49 +196,57 @@ namespace Nz
}, type);
}
void GlslWriter::Append(ShaderNodes::BuiltinEntry builtin)
void GlslWriter::Append(ShaderAst::BuiltinEntry builtin)
{
switch (builtin)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
case ShaderAst::BuiltinEntry::VertexPosition:
Append("gl_Position");
break;
}
}
void GlslWriter::Append(ShaderNodes::BasicType type)
void GlslWriter::Append(ShaderAst::BasicType type)
{
switch (type)
{
case ShaderNodes::BasicType::Boolean: return Append("bool");
case ShaderNodes::BasicType::Float1: return Append("float");
case ShaderNodes::BasicType::Float2: return Append("vec2");
case ShaderNodes::BasicType::Float3: return Append("vec3");
case ShaderNodes::BasicType::Float4: return Append("vec4");
case ShaderNodes::BasicType::Int1: return Append("int");
case ShaderNodes::BasicType::Int2: return Append("ivec2");
case ShaderNodes::BasicType::Int3: return Append("ivec3");
case ShaderNodes::BasicType::Int4: return Append("ivec4");
case ShaderNodes::BasicType::Mat4x4: return Append("mat4");
case ShaderNodes::BasicType::Sampler2D: return Append("sampler2D");
case ShaderNodes::BasicType::UInt1: return Append("uint");
case ShaderNodes::BasicType::UInt2: return Append("uvec2");
case ShaderNodes::BasicType::UInt3: return Append("uvec3");
case ShaderNodes::BasicType::UInt4: return Append("uvec4");
case ShaderNodes::BasicType::Void: return Append("void");
case ShaderAst::BasicType::Boolean: return Append("bool");
case ShaderAst::BasicType::Float1: return Append("float");
case ShaderAst::BasicType::Float2: return Append("vec2");
case ShaderAst::BasicType::Float3: return Append("vec3");
case ShaderAst::BasicType::Float4: return Append("vec4");
case ShaderAst::BasicType::Int1: return Append("int");
case ShaderAst::BasicType::Int2: return Append("ivec2");
case ShaderAst::BasicType::Int3: return Append("ivec3");
case ShaderAst::BasicType::Int4: return Append("ivec4");
case ShaderAst::BasicType::Mat4x4: return Append("mat4");
case ShaderAst::BasicType::Sampler2D: return Append("sampler2D");
case ShaderAst::BasicType::UInt1: return Append("uint");
case ShaderAst::BasicType::UInt2: return Append("uvec2");
case ShaderAst::BasicType::UInt3: return Append("uvec3");
case ShaderAst::BasicType::UInt4: return Append("uvec4");
case ShaderAst::BasicType::Void: return Append("void");
}
}
void GlslWriter::Append(ShaderNodes::MemoryLayout layout)
void GlslWriter::Append(ShaderAst::MemoryLayout layout)
{
switch (layout)
{
case ShaderNodes::MemoryLayout::Std140:
case ShaderAst::MemoryLayout::Std140:
Append("std140");
break;
}
}
template<typename T>
void GlslWriter::Append(const T& param)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
m_currentState->stream << param;
}
void GlslWriter::AppendCommentSection(const std::string& section)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@ -285,67 +256,24 @@ namespace Nz
AppendLine();
}
void GlslWriter::AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers)
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const auto& structs = m_context.shader->GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
assert(it != structs.end());
const ShaderAst::Struct& s = *it;
assert(*memberIndex < s.members.size());
const auto& member = s.members[*memberIndex];
Append(".");
Append(member.name);
Append(memberIdentifier[0]);
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
assert(memberIt != s.members.end());
const auto& member = *memberIt;
if (remainingMembers > 1)
{
assert(IsStructType(member.type));
AppendField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
}
}
void GlslWriter::AppendFunction(const ShaderAst::Function& func)
{
NazaraAssert(!m_context.currentFunction, "A function is already being processed");
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
AppendFunctionPrototype(func);
m_context.currentFunction = &func;
CallOnExit onExit([this] ()
{
m_context.currentFunction = nullptr;
});
EnterScope();
{
AstAdapter adapter;
adapter.flipYPosition = m_environment.flipYPosition;
Visit(adapter.Clone(func.statement));
}
LeaveScope();
}
void GlslWriter::AppendFunctionPrototype(const ShaderAst::Function& func)
{
Append(func.returnType);
Append(" ");
Append(func.name);
Append("(");
for (std::size_t i = 0; i < func.parameters.size(); ++i)
{
if (i != 0)
Append(", ");
Append(func.parameters[i].type);
Append(" ");
Append(func.parameters[i].name);
}
Append(")\n");
AppendField(scopeId, std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
}
void GlslWriter::AppendLine(const std::string& txt)
@ -372,44 +300,46 @@ namespace Nz
AppendLine("}");
}
void GlslWriter::Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired)
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
{
bool enclose = encloseIfRequired && (GetExpressionCategory(expr) != ShaderNodes::ExpressionCategory::LValue);
bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue);
if (enclose)
Append("(");
ShaderAstVisitor::Visit(expr);
expr->Visit(*this);
if (enclose)
Append(")");
}
void GlslWriter::Visit(ShaderNodes::AccessMember& node)
void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node)
{
Visit(node.structExpr, true);
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
assert(IsStructType(exprType));
AppendField(std::get<std::string>(exprType), node.memberIndices.data(), node.memberIndices.size());
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
AppendField(scopeId, std::get<std::string>(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void GlslWriter::Visit(ShaderNodes::AssignOp& node)
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
{
Visit(node.left);
node.left->Visit(*this);
switch (node.op)
{
case ShaderNodes::AssignType::Simple:
case ShaderAst::AssignType::Simple:
Append(" = ");
break;
}
Visit(node.right);
node.left->Visit(*this);
}
void GlslWriter::Visit(ShaderNodes::Branch& node)
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
{
bool first = true;
for (const auto& statement : node.condStatements)
@ -418,11 +348,11 @@ namespace Nz
Append("else ");
Append("if (");
Visit(statement.condition);
statement.condition->Visit(*this);
AppendLine(")");
EnterScope();
Visit(statement.statement);
statement.statement->Visit(*this);
LeaveScope();
first = false;
@ -433,41 +363,36 @@ namespace Nz
AppendLine("else");
EnterScope();
Visit(node.elseStatement);
node.elseStatement->Visit(*this);
LeaveScope();
}
}
void GlslWriter::Visit(ShaderNodes::BinaryOp& node)
void GlslWriter::Visit(ShaderAst::BinaryExpression& node)
{
Visit(node.left, true);
switch (node.op)
{
case ShaderNodes::BinaryType::Add: Append(" + "); break;
case ShaderNodes::BinaryType::Subtract: Append(" - "); break;
case ShaderNodes::BinaryType::Multiply: Append(" * "); break;
case ShaderNodes::BinaryType::Divide: Append(" / "); break;
case ShaderAst::BinaryType::Add: Append(" + "); break;
case ShaderAst::BinaryType::Subtract: Append(" - "); break;
case ShaderAst::BinaryType::Multiply: Append(" * "); break;
case ShaderAst::BinaryType::Divide: Append(" / "); break;
case ShaderNodes::BinaryType::CompEq: Append(" == "); break;
case ShaderNodes::BinaryType::CompGe: Append(" >= "); break;
case ShaderNodes::BinaryType::CompGt: Append(" > "); break;
case ShaderNodes::BinaryType::CompLe: Append(" <= "); break;
case ShaderNodes::BinaryType::CompLt: Append(" < "); break;
case ShaderNodes::BinaryType::CompNe: Append(" != "); break;
case ShaderAst::BinaryType::CompEq: Append(" == "); break;
case ShaderAst::BinaryType::CompGe: Append(" >= "); break;
case ShaderAst::BinaryType::CompGt: Append(" > "); break;
case ShaderAst::BinaryType::CompLe: Append(" <= "); break;
case ShaderAst::BinaryType::CompLt: Append(" < "); break;
case ShaderAst::BinaryType::CompNe: Append(" != "); break;
}
Visit(node.right, true);
}
void GlslWriter::Visit(ShaderNodes::BuiltinVariable& var)
void GlslWriter::Visit(ShaderAst::CastExpression& node)
{
Append(var.entry);
}
void GlslWriter::Visit(ShaderNodes::Cast& node)
{
Append(node.exprType);
Append(node.targetType);
Append("(");
bool first = true;
@ -479,34 +404,34 @@ namespace Nz
if (!first)
m_currentState->stream << ", ";
Visit(exprPtr);
exprPtr->Visit(*this);
first = false;
}
Append(")");
}
void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node)
void GlslWriter::Visit(ShaderAst::ConditionalExpression& node)
{
std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);
Visit(node.falsePath);*/
}
void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node)
void GlslWriter::Visit(ShaderAst::ConditionalStatement& node)
{
std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
Visit(node.statement);
Visit(node.statement);*/
}
void GlslWriter::Visit(ShaderNodes::Constant& node)
void GlslWriter::Visit(ShaderAst::ConstantExpression& node)
{
std::visit([&](auto&& arg)
{
@ -530,54 +455,74 @@ namespace Nz
}, node.value);
}
void GlslWriter::Visit(ShaderNodes::DeclareVariable& node)
void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
Append(localVar.type);
Append(node.returnType);
Append(" ");
Append(localVar.name);
if (node.expression)
Append(node.name);
Append("(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
{
if (i != 0)
Append(", ");
Append(node.parameters[i].type);
Append(" ");
Append(node.parameters[i].name);
}
Append(")\n");
EnterScope();
{
AstAdapter adapter;
adapter.flipYPosition = m_environment.flipYPosition;
for (auto& statement : node.statements)
adapter.Clone(statement)->Visit(*this);
}
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
Append(node.varType);
Append(" ");
Append(node.varName);
if (node.initialExpression)
{
Append(" = ");
Visit(node.expression);
node.initialExpression->Visit(*this);
}
AppendLine(";");
}
void GlslWriter::Visit(ShaderNodes::Discard& /*node*/)
void GlslWriter::Visit(ShaderAst::DiscardStatement& /*node*/)
{
Append("discard;");
}
void GlslWriter::Visit(ShaderNodes::ExpressionStatement& node)
void GlslWriter::Visit(ShaderAst::ExpressionStatement& node)
{
Visit(node.expression);
node.expression->Visit(*this);
Append(";");
}
void GlslWriter::Visit(ShaderNodes::Identifier& node)
void GlslWriter::Visit(ShaderAst::IdentifierExpression& node)
{
Visit(node.var);
Append(node.identifier);
}
void GlslWriter::Visit(ShaderNodes::InputVariable& var)
{
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::IntrinsicCall& node)
void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::CrossProduct:
Append("cross");
break;
case ShaderNodes::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::DotProduct:
Append("dot");
break;
}
@ -588,67 +533,43 @@ namespace Nz
if (i != 0)
Append(", ");
Visit(node.parameters[i]);
node.parameters[i]->Visit(*this);
}
Append(")");
}
void GlslWriter::Visit(ShaderNodes::LocalVariable& var)
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
{
Append(var.name);
bool first = true;
for (const ShaderAst::StatementPtr& statement : node.statements)
{
if (!first && statement->GetType() != ShaderAst::NodeType::NoOpStatement)
AppendLine();
statement->Visit(*this);
first = false;
}
}
void GlslWriter::Visit(ShaderNodes::NoOp& /*node*/)
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)
{
/* nothing to do */
}
void GlslWriter::Visit(ShaderNodes::ParameterVariable& var)
{
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::ReturnStatement& node)
void GlslWriter::Visit(ShaderAst::ReturnStatement& node)
{
if (node.returnExpr)
{
Append("return ");
Visit(node.returnExpr);
node.returnExpr->Visit(*this);
Append(";");
}
else
Append("return;");
}
void GlslWriter::Visit(ShaderNodes::OutputVariable& var)
{
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::Sample2D& node)
{
Append("texture(");
Visit(node.sampler);
Append(", ");
Visit(node.coordinates);
Append(")");
}
void GlslWriter::Visit(ShaderNodes::StatementBlock& node)
{
bool first = true;
for (const ShaderNodes::StatementPtr& statement : node.statements)
{
if (!first && statement->GetType() != ShaderNodes::NodeType::NoOp)
AppendLine();
Visit(statement);
first = false;
}
}
void GlslWriter::Visit(ShaderNodes::SwizzleOp& node)
void GlslWriter::Visit(ShaderAst::SwizzleExpression& node)
{
Visit(node.expression, true);
Append(".");
@ -657,44 +578,39 @@ namespace Nz
{
switch (node.components[i])
{
case ShaderNodes::SwizzleComponent::First:
case ShaderAst::SwizzleComponent::First:
Append("x");
break;
case ShaderNodes::SwizzleComponent::Second:
case ShaderAst::SwizzleComponent::Second:
Append("y");
break;
case ShaderNodes::SwizzleComponent::Third:
case ShaderAst::SwizzleComponent::Third:
Append("z");
break;
case ShaderNodes::SwizzleComponent::Fourth:
case ShaderAst::SwizzleComponent::Fourth:
Append("w");
break;
}
}
}
void GlslWriter::Visit(ShaderNodes::UniformVariable& var)
bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader)
{
Append(var.name);
}
bool GlslWriter::HasExplicitBinding(const ShaderAst& shader)
{
for (const auto& uniform : shader.GetUniforms())
/*for (const auto& uniform : shader.GetUniforms())
{
if (uniform.bindingIndex.has_value())
return true;
}
}*/
return false;
}
bool GlslWriter::HasExplicitLocation(const ShaderAst& shader)
bool GlslWriter::HasExplicitLocation(ShaderAst::StatementPtr& shader)
{
for (const auto& input : shader.GetInputs())
/*for (const auto& input : shader.GetInputs())
{
if (input.locationIndex.has_value())
return true;
@ -704,7 +620,7 @@ namespace Nz
{
if (output.locationIndex.has_value())
return true;
}
}*/
return false;
}

View File

@ -1,56 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void ShaderAst::AddCondition(std::string name)
{
auto& conditionEntry = m_conditions.emplace_back();
conditionEntry.name = std::move(name);
}
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderExpressionType returnType)
{
auto& functionEntry = m_functions.emplace_back();
functionEntry.name = std::move(name);
functionEntry.parameters = std::move(parameters);
functionEntry.returnType = returnType;
functionEntry.statement = std::move(statement);
}
void ShaderAst::AddInput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex)
{
auto& inputEntry = m_inputs.emplace_back();
inputEntry.name = std::move(name);
inputEntry.locationIndex = std::move(locationIndex);
inputEntry.type = std::move(type);
}
void ShaderAst::AddOutput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex)
{
auto& outputEntry = m_outputs.emplace_back();
outputEntry.name = std::move(name);
outputEntry.locationIndex = std::move(locationIndex);
outputEntry.type = std::move(type);
}
void ShaderAst::AddStruct(std::string name, std::vector<StructMember> members)
{
auto& structEntry = m_structs.emplace_back();
structEntry.name = std::move(name);
structEntry.members = std::move(members);
}
void ShaderAst::AddUniform(std::string name, ShaderExpressionType type, std::optional<std::size_t> bindingIndex, std::optional<ShaderNodes::MemoryLayout> memoryLayout)
{
auto& uniformEntry = m_uniforms.emplace_back();
uniformEntry.bindingIndex = std::move(bindingIndex);
uniformEntry.memoryLayout = std::move(memoryLayout);
uniformEntry.name = std::move(name);
uniformEntry.type = std::move(type);
}
}

View File

@ -6,240 +6,257 @@
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderNodes::StatementPtr ShaderAstCloner::Clone(const ShaderNodes::StatementPtr& statement)
ExpressionPtr AstCloner::Clone(ExpressionPtr& expr)
{
ShaderAstVisitor::Visit(statement);
expr->Visit(*this);
if (!m_expressionStack.empty() || !m_variableStack.empty() || m_statementStack.size() != 1)
throw std::runtime_error("An error occurred during clone");
assert(m_statementStack.empty() && m_expressionStack.size() == 1);
return PopExpression();
}
StatementPtr AstCloner::Clone(StatementPtr& statement)
{
statement->Visit(*this);
assert(m_expressionStack.empty() && m_statementStack.size() == 1);
return PopStatement();
}
ShaderNodes::ExpressionPtr ShaderAstCloner::CloneExpression(const ShaderNodes::ExpressionPtr& expr)
ExpressionPtr AstCloner::CloneExpression(ExpressionPtr& expr)
{
if (!expr)
return nullptr;
ShaderAstVisitor::Visit(expr);
expr->Visit(*this);
return PopExpression();
}
ShaderNodes::StatementPtr ShaderAstCloner::CloneStatement(const ShaderNodes::StatementPtr& statement)
StatementPtr AstCloner::CloneStatement(StatementPtr& statement)
{
if (!statement)
return nullptr;
ShaderAstVisitor::Visit(statement);
statement->Visit(*this);
return PopStatement();
}
ShaderNodes::VariablePtr ShaderAstCloner::CloneVariable(const ShaderNodes::VariablePtr& variable)
void AstCloner::Visit(AccessMemberExpression& node)
{
if (!variable)
return nullptr;
auto clone = std::make_unique<AccessMemberExpression>();
clone->memberIdentifiers = node.memberIdentifiers;
clone->structExpr = CloneExpression(node.structExpr);
ShaderVarVisitor::Visit(variable);
return PopVariable();
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::AccessMember& node)
void AstCloner::Visit(AssignExpression& node)
{
PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndices, node.exprType));
auto clone = std::make_unique<AssignExpression>();
clone->op = node.op;
clone->left = CloneExpression(node.left);
clone->right = CloneExpression(node.right);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::AssignOp& node)
void AstCloner::Visit(BinaryExpression& node)
{
PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right)));
auto clone = std::make_unique<BinaryExpression>();
clone->op = node.op;
clone->left = CloneExpression(node.left);
clone->right = CloneExpression(node.right);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::BinaryOp& node)
void AstCloner::Visit(CastExpression& node)
{
PushExpression(ShaderNodes::BinaryOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right)));
}
auto clone = std::make_unique<CastExpression>();
clone->targetType = node.targetType;
void ShaderAstCloner::Visit(ShaderNodes::Branch& node)
{
std::vector<ShaderNodes::Branch::ConditionalStatement> condStatements;
condStatements.reserve(node.condStatements.size());
for (auto& cond : node.condStatements)
{
auto& condStatement = condStatements.emplace_back();
condStatement.condition = CloneExpression(cond.condition);
condStatement.statement = CloneStatement(cond.statement);
}
PushStatement(ShaderNodes::Branch::Build(std::move(condStatements), CloneStatement(node.elseStatement)));
}
void ShaderAstCloner::Visit(ShaderNodes::Cast& node)
{
std::size_t expressionCount = 0;
std::array<ShaderNodes::ExpressionPtr, 4> expressions;
for (auto& expr : node.expressions)
{
if (!expr)
break;
expressions[expressionCount] = CloneExpression(expr);
expressionCount++;
clone->expressions[expressionCount++] = CloneExpression(expr);
}
PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount));
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node)
void AstCloner::Visit(ConditionalExpression& node)
{
PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath)));
auto clone = std::make_unique<ConditionalExpression>();
clone->conditionName = node.conditionName;
clone->falsePath = CloneExpression(node.falsePath);
clone->truePath = CloneExpression(node.truePath);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node)
void AstCloner::Visit(ConstantExpression& node)
{
PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement)));
auto clone = std::make_unique<ConstantExpression>();
clone->value = node.value;
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::Constant& node)
void AstCloner::Visit(IdentifierExpression& node)
{
PushExpression(ShaderNodes::Constant::Build(node.value));
auto clone = std::make_unique<IdentifierExpression>();
clone->identifier = node.identifier;
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::DeclareVariable& node)
void AstCloner::Visit(IntrinsicExpression& node)
{
PushStatement(ShaderNodes::DeclareVariable::Build(CloneVariable(node.variable), CloneExpression(node.expression)));
}
void ShaderAstCloner::Visit(ShaderNodes::Discard& /*node*/)
{
PushStatement(ShaderNodes::Discard::Build());
}
void ShaderAstCloner::Visit(ShaderNodes::ExpressionStatement& node)
{
PushStatement(ShaderNodes::ExpressionStatement::Build(CloneExpression(node.expression)));
}
void ShaderAstCloner::Visit(ShaderNodes::Identifier& node)
{
PushExpression(ShaderNodes::Identifier::Build(CloneVariable(node.var)));
}
void ShaderAstCloner::Visit(ShaderNodes::IntrinsicCall& node)
{
std::vector<ShaderNodes::ExpressionPtr> parameters;
parameters.reserve(node.parameters.size());
auto clone = std::make_unique<IntrinsicExpression>();
clone->intrinsic = node.intrinsic;
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
parameters.push_back(CloneExpression(parameter));
clone->parameters.push_back(CloneExpression(parameter));
PushExpression(ShaderNodes::IntrinsicCall::Build(node.intrinsic, std::move(parameters)));
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::NoOp& /*node*/)
void AstCloner::Visit(SwizzleExpression& node)
{
PushStatement(ShaderNodes::NoOp::Build());
auto clone = std::make_unique<SwizzleExpression>();
clone->componentCount = node.componentCount;
clone->components = node.components;
clone->expression = CloneExpression(node.expression);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ReturnStatement& node)
void AstCloner::Visit(BranchStatement& node)
{
PushStatement(ShaderNodes::ReturnStatement::Build(CloneExpression(node.returnExpr)));
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
for (auto& cond : node.condStatements)
{
auto& condStatement = clone->condStatements.emplace_back();
condStatement.condition = CloneExpression(cond.condition);
condStatement.statement = CloneStatement(cond.statement);
}
clone->elseStatement = CloneStatement(node.elseStatement);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node)
void AstCloner::Visit(ConditionalStatement& node)
{
PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates)));
auto clone = std::make_unique<ConditionalStatement>();
clone->conditionName = node.conditionName;
clone->statement = CloneStatement(node.statement);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::StatementBlock& node)
void AstCloner::Visit(DeclareFunctionStatement& node)
{
std::vector<ShaderNodes::StatementPtr> statements;
statements.reserve(node.statements.size());
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
statements.push_back(CloneStatement(statement));
clone->statements.push_back(CloneStatement(statement));
PushStatement(ShaderNodes::StatementBlock::Build(std::move(statements)));
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::SwizzleOp& node)
void AstCloner::Visit(DeclareStructStatement& node)
{
PushExpression(ShaderNodes::SwizzleOp::Build(CloneExpression(node.expression), node.components.data(), node.componentCount));
auto clone = std::make_unique<DeclareStructStatement>();
clone->description = node.description;
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::BuiltinVariable& var)
void AstCloner::Visit(DeclareVariableStatement& node)
{
PushVariable(ShaderNodes::BuiltinVariable::Build(var.entry, var.type));
auto clone = std::make_unique<DeclareVariableStatement>();
clone->varName = node.varName;
clone->varType = node.varType;
clone->initialExpression = CloneExpression(node.initialExpression);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::InputVariable& var)
void AstCloner::Visit(DiscardStatement& /*node*/)
{
PushVariable(ShaderNodes::InputVariable::Build(var.name, var.type));
PushStatement(std::make_unique<DiscardStatement>());
}
void ShaderAstCloner::Visit(ShaderNodes::LocalVariable& var)
void AstCloner::Visit(ExpressionStatement& node)
{
PushVariable(ShaderNodes::LocalVariable::Build(var.name, var.type));
auto clone = std::make_unique<ExpressionStatement>();
clone->expression = CloneExpression(node.expression);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::OutputVariable& var)
void AstCloner::Visit(MultiStatement& node)
{
PushVariable(ShaderNodes::OutputVariable::Build(var.name, var.type));
auto clone = std::make_unique<MultiStatement>();
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ParameterVariable& var)
void AstCloner::Visit(NoOpStatement& /*node*/)
{
PushVariable(ShaderNodes::ParameterVariable::Build(var.name, var.type));
PushStatement(std::make_unique<NoOpStatement>());
}
void ShaderAstCloner::Visit(ShaderNodes::UniformVariable& var)
void AstCloner::Visit(ReturnStatement& node)
{
PushVariable(ShaderNodes::UniformVariable::Build(var.name, var.type));
auto clone = std::make_unique<ReturnStatement>();
clone->returnExpr = CloneExpression(node.returnExpr);
PushStatement(std::move(clone));
}
void ShaderAstCloner::PushExpression(ShaderNodes::ExpressionPtr expression)
void AstCloner::PushExpression(ExpressionPtr expression)
{
m_expressionStack.emplace_back(std::move(expression));
}
void ShaderAstCloner::PushStatement(ShaderNodes::StatementPtr statement)
void AstCloner::PushStatement(StatementPtr statement)
{
m_statementStack.emplace_back(std::move(statement));
}
void ShaderAstCloner::PushVariable(ShaderNodes::VariablePtr variable)
{
m_variableStack.emplace_back(std::move(variable));
}
ShaderNodes::ExpressionPtr ShaderAstCloner::PopExpression()
ExpressionPtr AstCloner::PopExpression()
{
assert(!m_expressionStack.empty());
ShaderNodes::ExpressionPtr expr = std::move(m_expressionStack.back());
ExpressionPtr expr = std::move(m_expressionStack.back());
m_expressionStack.pop_back();
return expr;
}
ShaderNodes::StatementPtr ShaderAstCloner::PopStatement()
StatementPtr AstCloner::PopStatement()
{
assert(!m_statementStack.empty());
ShaderNodes::StatementPtr expr = std::move(m_statementStack.back());
StatementPtr expr = std::move(m_statementStack.back());
m_statementStack.pop_back();
return expr;
}
ShaderNodes::VariablePtr ShaderAstCloner::PopVariable()
{
assert(!m_variableStack.empty());
ShaderNodes::VariablePtr var = std::move(m_variableStack.back());
m_variableStack.pop_back();
return var;
}
}

View File

@ -0,0 +1,198 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <optional>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr)
{
m_cache = cache;
ShaderExpressionType type = GetExpressionTypeInternal(expression);
m_cache = nullptr;
return type;
}
ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
{
m_lastExpressionType.reset();
Visit(expression);
assert(m_lastExpressionType.has_value());
return std::move(*m_lastExpressionType);
}
void ExpressionTypeVisitor::Visit(Expression& expression)
{
if (m_cache)
{
auto it = m_cache->nodeExpressionType.find(&expression);
if (it != m_cache->nodeExpressionType.end())
{
m_lastExpressionType = it->second;
return;
}
}
expression.Visit(*this);
if (m_cache)
{
assert(m_lastExpressionType.has_value());
m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType);
}
}
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
{
throw std::runtime_error("unhandled accessmember expression");
}
void ExpressionTypeVisitor::Visit(AssignExpression& node)
{
Visit(*node.left);
}
void ExpressionTypeVisitor::Visit(BinaryExpression& node)
{
switch (node.op)
{
case BinaryType::Add:
case BinaryType::Subtract:
return Visit(*node.left);
case BinaryType::Divide:
case BinaryType::Multiply:
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left);
assert(IsBasicType(leftExprType));
ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right);
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
{
case BasicType::Boolean:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
m_lastExpressionType = std::move(leftExprType);
break;
case BasicType::Float1:
case BasicType::Int1:
case BasicType::Mat4x4:
case BasicType::UInt1:
m_lastExpressionType = std::move(rightExprType);
break;
case BasicType::Sampler2D:
case BasicType::Void:
break;
}
break;
}
case BinaryType::CompEq:
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
m_lastExpressionType = BasicType::Boolean;
break;
}
}
void ExpressionTypeVisitor::Visit(CastExpression& node)
{
m_lastExpressionType = node.targetType;
}
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath);
assert(leftExprType == GetExpressionTypeInternal(*node.falsePath));
m_lastExpressionType = std::move(leftExprType);
}
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
{
m_lastExpressionType = std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return BasicType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return BasicType::Float1;
else if constexpr (std::is_same_v<T, Int32>)
return BasicType::Int1;
else if constexpr (std::is_same_v<T, UInt32>)
return BasicType::Int1;
else if constexpr (std::is_same_v<T, Vector2f>)
return BasicType::Float2;
else if constexpr (std::is_same_v<T, Vector3f>)
return BasicType::Float3;
else if constexpr (std::is_same_v<T, Vector4f>)
return BasicType::Float4;
else if constexpr (std::is_same_v<T, Vector2i32>)
return BasicType::Int2;
else if constexpr (std::is_same_v<T, Vector3i32>)
return BasicType::Int3;
else if constexpr (std::is_same_v<T, Vector4i32>)
return BasicType::Int4;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, node.value);
}
void ExpressionTypeVisitor::Visit(IdentifierExpression& node)
{
auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier);
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
throw std::runtime_error("internal error");
m_lastExpressionType = std::get<AstCache::Variable>(identifier->value).type;
}
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
Visit(*node.parameters.front());
break;
case IntrinsicType::DotProduct:
m_lastExpressionType = BasicType::Float1;
break;
}
}
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
{
const ShaderExpressionType& exprType = GetExpressionTypeInternal(*node.expression);
assert(IsBasicType(exprType));
m_lastExpressionType = static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + node.componentCount - 1);
}
}

View File

@ -2,15 +2,10 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderAstVisitor::~ShaderAstVisitor() = default;
void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node)
{
node->Visit(*this);
}
AstExpressionVisitor::~AstExpressionVisitor() = default;
}

View File

@ -0,0 +1,15 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
}

View File

@ -3,16 +3,22 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <cassert>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
namespace
{
template<typename T, typename U>
std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
{
return std::unique_ptr<T>(static_cast<T*>(ptr.release()));
}
template <typename T>
struct is_complete_helper
{
@ -29,14 +35,14 @@ namespace Nz
inline constexpr bool is_complete_v = is_complete<T>::value;
template<ShaderNodes::BinaryType Type, typename T1, typename T2>
template<BinaryType Type, typename T1, typename T2>
struct PropagateConstantType;
// CompEq
template<typename T1, typename T2>
struct CompEqBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs == rhs);
}
@ -46,7 +52,7 @@ namespace Nz
struct CompEq;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompEq, T1, T2>
struct PropagateConstantType<BinaryType::CompEq, T1, T2>
{
using Op = typename CompEq<T1, T2>;
};
@ -55,7 +61,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompGeBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs >= rhs);
}
@ -65,7 +71,7 @@ namespace Nz
struct CompGe;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompGe, T1, T2>
struct PropagateConstantType<BinaryType::CompGe, T1, T2>
{
using Op = typename CompGe<T1, T2>;
};
@ -74,7 +80,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompGtBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs > rhs);
}
@ -84,7 +90,7 @@ namespace Nz
struct CompGt;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompGt, T1, T2>
struct PropagateConstantType<BinaryType::CompGt, T1, T2>
{
using Op = typename CompGt<T1, T2>;
};
@ -93,7 +99,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompLeBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs <= rhs);
}
@ -103,7 +109,7 @@ namespace Nz
struct CompLe;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompLe, T1, T2>
struct PropagateConstantType<BinaryType::CompLe, T1, T2>
{
using Op = typename CompLe<T1, T2>;
};
@ -112,7 +118,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompLtBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs < rhs);
}
@ -122,7 +128,7 @@ namespace Nz
struct CompLt;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompLt, T1, T2>
struct PropagateConstantType<BinaryType::CompLt, T1, T2>
{
using Op = typename CompLe<T1, T2>;
};
@ -131,7 +137,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompNeBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs != rhs);
}
@ -141,7 +147,7 @@ namespace Nz
struct CompNe;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompNe, T1, T2>
struct PropagateConstantType<BinaryType::CompNe, T1, T2>
{
using Op = typename CompNe<T1, T2>;
};
@ -150,7 +156,7 @@ namespace Nz
template<typename T1, typename T2>
struct AdditionBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs + rhs);
}
@ -160,7 +166,7 @@ namespace Nz
struct Addition;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Add, T1, T2>
struct PropagateConstantType<BinaryType::Add, T1, T2>
{
using Op = typename Addition<T1, T2>;
};
@ -169,7 +175,7 @@ namespace Nz
template<typename T1, typename T2>
struct DivisionBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs / rhs);
}
@ -179,7 +185,7 @@ namespace Nz
struct Division;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Divide, T1, T2>
struct PropagateConstantType<BinaryType::Divide, T1, T2>
{
using Op = typename Division<T1, T2>;
};
@ -188,7 +194,7 @@ namespace Nz
template<typename T1, typename T2>
struct MultiplicationBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs * rhs);
}
@ -198,7 +204,7 @@ namespace Nz
struct Multiplication;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Multiply, T1, T2>
struct PropagateConstantType<BinaryType::Multiply, T1, T2>
{
using Op = typename Multiplication<T1, T2>;
};
@ -207,7 +213,7 @@ namespace Nz
template<typename T1, typename T2>
struct SubtractionBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs - rhs);
}
@ -217,7 +223,7 @@ namespace Nz
struct Subtraction;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Subtract, T1, T2>
struct PropagateConstantType<BinaryType::Subtract, T1, T2>
{
using Op = typename Subtraction<T1, T2>;
};
@ -375,92 +381,89 @@ namespace Nz
#undef EnableOptimisation
}
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement)
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
{
m_shaderAst = nullptr;
return CloneStatement(statement);
}
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions)
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
{
m_shaderAst = &shader;
m_enabledConditions = enabledConditions;
return CloneStatement(statement);
}
void ShaderAstOptimizer::Visit(ShaderNodes::BinaryOp& node)
void AstOptimizer::Visit(BinaryExpression& node)
{
auto lhs = CloneExpression(node.left);
auto rhs = CloneExpression(node.right);
if (lhs->GetType() == ShaderNodes::NodeType::Constant && rhs->GetType() == ShaderNodes::NodeType::Constant)
if (lhs->GetType() == NodeType::ConstantExpression && rhs->GetType() == NodeType::ConstantExpression)
{
auto lhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(lhs);
auto rhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(rhs);
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
switch (node.op)
{
case ShaderNodes::BinaryType::Add:
return PropagateConstant<ShaderNodes::BinaryType::Add>(lhsConstant, rhsConstant);
case BinaryType::Add:
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::Subtract:
return PropagateConstant<ShaderNodes::BinaryType::Subtract>(lhsConstant, rhsConstant);
case BinaryType::Subtract:
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::Multiply:
return PropagateConstant<ShaderNodes::BinaryType::Multiply>(lhsConstant, rhsConstant);
case BinaryType::Multiply:
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::Divide:
return PropagateConstant<ShaderNodes::BinaryType::Divide>(lhsConstant, rhsConstant);
case BinaryType::Divide:
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompEq:
return PropagateConstant<ShaderNodes::BinaryType::CompEq>(lhsConstant, rhsConstant);
case BinaryType::CompEq:
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompGe:
return PropagateConstant<ShaderNodes::BinaryType::CompGe>(lhsConstant, rhsConstant);
case BinaryType::CompGe:
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompGt:
return PropagateConstant<ShaderNodes::BinaryType::CompGt>(lhsConstant, rhsConstant);
case BinaryType::CompGt:
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompLe:
return PropagateConstant<ShaderNodes::BinaryType::CompLe>(lhsConstant, rhsConstant);
case BinaryType::CompLe:
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompLt:
return PropagateConstant<ShaderNodes::BinaryType::CompLt>(lhsConstant, rhsConstant);
case BinaryType::CompLt:
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompNe:
return PropagateConstant<ShaderNodes::BinaryType::CompNe>(lhsConstant, rhsConstant);
case BinaryType::CompNe:
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
}
}
ShaderAstCloner::Visit(node);
AstCloner::Visit(node);
}
void ShaderAstOptimizer::Visit(ShaderNodes::Branch& node)
void AstOptimizer::Visit(BranchStatement& node)
{
std::vector<ShaderNodes::Branch::ConditionalStatement> statements;
ShaderNodes::StatementPtr elseStatement;
std::vector<BranchStatement::ConditionalStatement> statements;
StatementPtr elseStatement;
for (auto& condStatement : node.condStatements)
{
auto cond = CloneExpression(condStatement.condition);
if (cond->GetType() == ShaderNodes::NodeType::Constant)
if (cond->GetType() == NodeType::ConstantExpression)
{
auto constant = std::static_pointer_cast<ShaderNodes::Constant>(cond);
auto& constant = static_cast<ConstantExpression&>(*cond);
assert(IsBasicType(cond->GetExpressionType()));
assert(std::get<ShaderNodes::BasicType>(cond->GetExpressionType()) == ShaderNodes::BasicType::Boolean);
assert(IsBasicType(GetExpressionType(constant)));
assert(std::get<BasicType>(GetExpressionType(constant)) == BasicType::Boolean);
bool cValue = std::get<bool>(constant->value);
bool cValue = std::get<bool>(constant.value);
if (!cValue)
continue;
if (statements.empty())
{
// First condition is true, dismiss the branch
Visit(condStatement.statement);
condStatement.statement->Visit(*this);
return;
}
else
@ -482,47 +485,54 @@ namespace Nz
{
// All conditions have been removed, replace by else statement or no-op
if (node.elseStatement)
return Visit(node.elseStatement);
{
node.elseStatement->Visit(*this);
return;
}
else
return PushStatement(ShaderNodes::NoOp::Build());
return PushStatement(ShaderBuilder::NoOp());
}
if (!elseStatement)
elseStatement = CloneStatement(node.elseStatement);
PushStatement(ShaderNodes::Branch::Build(std::move(statements), std::move(elseStatement)));
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)));
}
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalExpression& node)
void AstOptimizer::Visit(ConditionalExpression& node)
{
if (!m_shaderAst)
return AstCloner::Visit(node);
/*if (!m_shaderAst)
return ShaderAstCloner::Visit(node);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);
Visit(node.falsePath);*/
}
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalStatement& node)
void AstOptimizer::Visit(ConditionalStatement& node)
{
if (!m_shaderAst)
return AstCloner::Visit(node);
/*if (!m_shaderAst)
return ShaderAstCloner::Visit(node);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.statement);
Visit(node.statement);*/
}
template<ShaderNodes::BinaryType Type>
void ShaderAstOptimizer::PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs)
template<BinaryType Type>
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
{
ShaderNodes::ExpressionPtr optimized;
ExpressionPtr optimized;
std::visit([&](auto&& arg1)
{
using T1 = std::decay_t<decltype(arg1)>;
@ -543,8 +553,8 @@ namespace Nz
}, lhs->value);
if (optimized)
PushExpression(optimized);
PushExpression(std::move(optimized));
else
PushExpression(ShaderNodes::BinaryOp::Build(Type, lhs, rhs));
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs)));
}
}

View File

@ -5,116 +5,121 @@
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AccessMember& node)
void AstRecursiveVisitor::Visit(AccessMemberExpression& node)
{
Visit(node.structExpr);
node.structExpr->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AssignOp& node)
void AstRecursiveVisitor::Visit(AssignExpression& node)
{
Visit(node.left);
Visit(node.right);
node.left->Visit(*this);
node.right->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::BinaryOp& node)
void AstRecursiveVisitor::Visit(BinaryExpression& node)
{
Visit(node.left);
Visit(node.right);
node.left->Visit(*this);
node.right->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Branch& node)
{
for (auto& cond : node.condStatements)
{
Visit(cond.condition);
Visit(cond.statement);
}
if (node.elseStatement)
Visit(node.elseStatement);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Cast& node)
void AstRecursiveVisitor::Visit(CastExpression& node)
{
for (auto& expr : node.expressions)
{
if (!expr)
break;
Visit(expr);
expr->Visit(*this);
}
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node)
void AstRecursiveVisitor::Visit(ConditionalExpression& node)
{
Visit(node.truePath);
Visit(node.falsePath);
node.truePath->Visit(*this);
node.falsePath->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node)
{
Visit(node.statement);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/)
void AstRecursiveVisitor::Visit(ConstantExpression& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::DeclareVariable& node)
{
if (node.expression)
Visit(node.expression);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Discard& /*node*/)
void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ExpressionStatement& node)
{
Visit(node.expression);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Identifier& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::IntrinsicCall& node)
void AstRecursiveVisitor::Visit(IntrinsicExpression& node)
{
for (auto& param : node.parameters)
Visit(param);
param->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::NoOp& /*node*/)
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(BranchStatement& node)
{
for (auto& cond : node.condStatements)
{
cond.condition->Visit(*this);
cond.statement->Visit(*this);
}
if (node.elseStatement)
node.elseStatement->Visit(*this);
}
void AstRecursiveVisitor::Visit(ConditionalStatement& node)
{
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node)
{
for (auto& statement : node.statements)
statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareStructStatement& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ReturnStatement& node)
void AstRecursiveVisitor::Visit(DeclareVariableStatement& node)
{
if (node.returnExpr)
Visit(node.returnExpr);
if (node.initialExpression)
node.initialExpression->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node)
void AstRecursiveVisitor::Visit(DiscardStatement& /*node*/)
{
Visit(node.sampler);
Visit(node.coordinates);
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::StatementBlock& node)
void AstRecursiveVisitor::Visit(ExpressionStatement& node)
{
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(MultiStatement& node)
{
for (auto& statement : node.statements)
Visit(statement);
statement->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::SwizzleOp& node)
void AstRecursiveVisitor::Visit(NoOpStatement& /*node*/)
{
Visit(node.expression);
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(ReturnStatement& node)
{
if (node.returnExpr)
node.returnExpr->Visit(*this);
}
}

View File

@ -3,221 +3,74 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
namespace
{
constexpr UInt32 s_magicNumber = 0x4E534852;
constexpr UInt32 s_currentVersion = 1;
class ShaderSerializerVisitor : public ShaderAstVisitor, public ShaderVarVisitor
class ShaderSerializerVisitor : public AstExpressionVisitor, public AstStatementVisitor
{
public:
ShaderSerializerVisitor(ShaderAstSerializerBase& serializer) :
ShaderSerializerVisitor(AstSerializerBase& serializer) :
m_serializer(serializer)
{
}
void Visit(ShaderNodes::AccessMember& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::AssignOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::BinaryOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Branch& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Cast& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ConditionalExpression& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ConditionalStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Constant& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::DeclareVariable& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Discard& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ExpressionStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Identifier& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::IntrinsicCall& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::NoOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ReturnStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Sample2D& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::StatementBlock& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::SwizzleOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::BuiltinVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::InputVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::LocalVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::OutputVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::ParameterVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::UniformVariable& var) override
{
Serialize(var);
#define NAZARA_SHADERAST_NODE(Node) void Visit(Node& node) override \
{ \
m_serializer.Serialize(node); \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
private:
template<typename T>
void Serialize(const T& node)
{
// I know const_cast is evil but I don't have a better solution here (it's not used to write)
m_serializer.Serialize(const_cast<T&>(node));
}
ShaderAstSerializerBase& m_serializer;
AstSerializerBase& m_serializer;
};
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::AccessMember& node)
void AstSerializerBase::Serialize(AccessMemberExpression& node)
{
Node(node.structExpr);
Type(node.exprType);
Container(node.memberIndices);
for (std::size_t& index : node.memberIndices)
SizeT(index);
Container(node.memberIdentifiers);
for (std::string& identifier : node.memberIdentifiers)
Value(identifier);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::AssignOp& node)
void AstSerializerBase::Serialize(AssignExpression& node)
{
Enum(node.op);
Node(node.left);
Node(node.right);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::BinaryOp& node)
void AstSerializerBase::Serialize(BinaryExpression& node)
{
Enum(node.op);
Node(node.left);
Node(node.right);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Branch& node)
void AstSerializerBase::Serialize(CastExpression& node)
{
Container(node.condStatements);
for (auto& condStatement : node.condStatements)
{
Node(condStatement.condition);
Node(condStatement.statement);
}
Node(node.elseStatement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node)
{
Enum(node.entry);
Type(node.type);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Cast& node)
{
Enum(node.exprType);
Enum(node.targetType);
for (auto& expr : node.expressions)
Node(expr);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node)
void AstSerializerBase::Serialize(ConditionalExpression& node)
{
Value(node.conditionName);
Node(node.truePath);
Node(node.falsePath);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node)
{
Value(node.conditionName);
Node(node.statement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node)
void AstSerializerBase::Serialize(ConstantExpression& node)
{
UInt32 typeIndex;
if (IsWriting())
@ -251,28 +104,19 @@ namespace Nz
}
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::DeclareVariable& node)
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{
Variable(node.variable);
Node(node.expression);
Value(node.varName);
Type(node.varType);
Node(node.initialExpression);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Discard& /*node*/)
void AstSerializerBase::Serialize(IdentifierExpression& node)
{
/* Nothing to do */
Value(node.identifier);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ExpressionStatement& node)
{
Node(node.expression);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Identifier& node)
{
Variable(node.var);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::IntrinsicCall& node)
void AstSerializerBase::Serialize(IntrinsicExpression& node)
{
Enum(node.intrinsic);
Container(node.parameters);
@ -280,36 +124,7 @@ namespace Nz
Node(param);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::NamedVariable& node)
{
Value(node.name);
Type(node.type);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::NoOp& /*node*/)
{
/* Nothing to do */
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ReturnStatement& node)
{
Node(node.returnExpr);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Sample2D& node)
{
Node(node.sampler);
Node(node.coordinates);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::StatementBlock& node)
{
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::SwizzleOp& node)
void AstSerializerBase::Serialize(SwizzleExpression& node)
{
SizeT(node.componentCount);
Node(node.expression);
@ -319,100 +134,85 @@ namespace Nz
}
void ShaderAstSerializer::Serialize(const ShaderAst& shader)
void AstSerializerBase::Serialize(BranchStatement& node)
{
Container(node.condStatements);
for (auto& condStatement : node.condStatements)
{
Node(condStatement.condition);
Node(condStatement.statement);
}
Node(node.elseStatement);
}
void AstSerializerBase::Serialize(ConditionalStatement& node)
{
Value(node.conditionName);
Node(node.statement);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);
Type(node.returnType);
Container(node.parameters);
for (auto& parameter : node.parameters)
{
Value(parameter.name);
Type(parameter.type);
}
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void AstSerializerBase::Serialize(DeclareStructStatement& node)
{
Value(node.description.name);
Container(node.description.members);
for (auto& member : node.description.members)
{
Value(member.name);
Type(member.type);
}
}
void AstSerializerBase::Serialize(DiscardStatement& /*node*/)
{
/* Nothing to do */
}
void AstSerializerBase::Serialize(ExpressionStatement& node)
{
Node(node.expression);
}
void AstSerializerBase::Serialize(MultiStatement& node)
{
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void AstSerializerBase::Serialize(NoOpStatement& /*node*/)
{
/* Nothing to do */
}
void AstSerializerBase::Serialize(ReturnStatement& node)
{
Node(node.returnExpr);
}
void ShaderAstSerializer::Serialize(StatementPtr& shader)
{
m_stream << s_magicNumber << s_currentVersion;
m_stream << UInt32(shader.GetStage());
auto SerializeType = [&](const ShaderExpressionType& type)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
{
m_stream << UInt8(0);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, std::string>)
{
m_stream << UInt8(1);
m_stream << arg;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
};
auto SerializeInputOutput = [&](auto& inout)
{
m_stream << UInt32(inout.size());
for (const auto& data : inout)
{
m_stream << data.name;
SerializeType(data.type);
m_stream << data.locationIndex.has_value();
if (data.locationIndex)
m_stream << UInt32(data.locationIndex.value());
}
};
// Conditions
m_stream << UInt32(shader.GetConditionCount());
for (const auto& cond : shader.GetConditions())
m_stream << cond.name;
// Structs
m_stream << UInt32(shader.GetStructCount());
for (const auto& s : shader.GetStructs())
{
m_stream << s.name;
m_stream << UInt32(s.members.size());
for (const auto& member : s.members)
{
m_stream << member.name;
SerializeType(member.type);
}
}
// Inputs / Outputs
SerializeInputOutput(shader.GetInputs());
SerializeInputOutput(shader.GetOutputs());
// Uniforms
m_stream << UInt32(shader.GetUniformCount());
for (const auto& uniform : shader.GetUniforms())
{
m_stream << uniform.name;
SerializeType(uniform.type);
m_stream << uniform.bindingIndex.has_value();
if (uniform.bindingIndex)
m_stream << UInt32(uniform.bindingIndex.value());
m_stream << uniform.memoryLayout.has_value();
if (uniform.memoryLayout)
m_stream << UInt32(uniform.memoryLayout.value());
}
// Functions
m_stream << UInt32(shader.GetFunctionCount());
for (const auto& func : shader.GetFunctions())
{
m_stream << func.name;
SerializeType(func.returnType);
m_stream << UInt32(func.parameters.size());
for (const auto& param : func.parameters)
{
m_stream << param.name;
SerializeType(param.type);
}
Node(func.statement);
}
Node(shader);
m_stream.FlushBits();
}
@ -422,9 +222,21 @@ namespace Nz
return true;
}
void ShaderAstSerializer::Node(ShaderNodes::NodePtr& node)
void ShaderAstSerializer::Node(ExpressionPtr& node)
{
ShaderNodes::NodeType nodeType = (node) ? node->GetType() : ShaderNodes::NodeType::None;
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
m_stream << static_cast<Int32>(nodeType);
if (node)
{
ShaderSerializerVisitor visitor(*this);
node->Visit(visitor);
}
}
void ShaderAstSerializer::Node(StatementPtr& node)
{
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
m_stream << static_cast<Int32>(nodeType);
if (node)
@ -439,7 +251,7 @@ namespace Nz
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
if constexpr (std::is_same_v<T, BasicType>)
{
m_stream << UInt8(0);
m_stream << UInt32(arg);
@ -454,11 +266,6 @@ namespace Nz
}, type);
}
void ShaderAstSerializer::Node(const ShaderNodes::NodePtr& node)
{
Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing
}
void ShaderAstSerializer::Value(bool& val)
{
m_stream << val;
@ -529,19 +336,7 @@ namespace Nz
m_stream << val;
}
void ShaderAstSerializer::Variable(ShaderNodes::VariablePtr& var)
{
ShaderNodes::VariableType nodeType = (var) ? var->GetType() : ShaderNodes::VariableType::None;
m_stream << static_cast<Int32>(nodeType);
if (var)
{
ShaderSerializerVisitor visitor(*this);
var->Visit(visitor);
}
}
ShaderAst ShaderAstUnserializer::Unserialize()
StatementPtr ShaderAstUnserializer::Unserialize()
{
UInt32 magicNumber;
UInt32 version;
@ -553,122 +348,13 @@ namespace Nz
if (version > s_currentVersion)
throw std::runtime_error("unsupported version");
UInt32 shaderStage;
m_stream >> shaderStage;
StatementPtr node;
ShaderAst shader(static_cast<ShaderStageType>(shaderStage));
Node(node);
if (!node)
throw std::runtime_error("functions can only have statements");
// Conditions
UInt32 conditionCount;
m_stream >> conditionCount;
for (UInt32 i = 0; i < conditionCount; ++i)
{
std::string conditionName;
Value(conditionName);
shader.AddCondition(std::move(conditionName));
}
// Structs
UInt32 structCount;
m_stream >> structCount;
for (UInt32 i = 0; i < structCount; ++i)
{
std::string structName;
std::vector<ShaderAst::StructMember> members;
Value(structName);
Container(members);
for (auto& member : members)
{
Value(member.name);
Type(member.type);
}
shader.AddStruct(std::move(structName), std::move(members));
}
// Inputs
UInt32 inputCount;
m_stream >> inputCount;
for (UInt32 i = 0; i < inputCount; ++i)
{
std::string inputName;
ShaderExpressionType inputType;
std::optional<std::size_t> location;
Value(inputName);
Type(inputType);
OptVal(location);
shader.AddInput(std::move(inputName), std::move(inputType), location);
}
// Outputs
UInt32 outputCount;
m_stream >> outputCount;
for (UInt32 i = 0; i < outputCount; ++i)
{
std::string outputName;
ShaderExpressionType outputType;
std::optional<std::size_t> location;
Value(outputName);
Type(outputType);
OptVal(location);
shader.AddOutput(std::move(outputName), std::move(outputType), location);
}
// Uniforms
UInt32 uniformCount;
m_stream >> uniformCount;
for (UInt32 i = 0; i < uniformCount; ++i)
{
std::string name;
ShaderExpressionType type;
std::optional<std::size_t> binding;
std::optional<ShaderNodes::MemoryLayout> memLayout;
Value(name);
Type(type);
OptVal(binding);
OptEnum(memLayout);
shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout));
}
// Functions
UInt32 funcCount;
m_stream >> funcCount;
for (UInt32 i = 0; i < funcCount; ++i)
{
std::string name;
ShaderExpressionType retType;
std::vector<ShaderAst::FunctionParameter> parameters;
Value(name);
Type(retType);
Container(parameters);
for (auto& param : parameters)
{
Value(param.name);
Type(param.type);
}
ShaderNodes::NodePtr node;
Node(node);
if (!node || !node->IsStatement())
throw std::runtime_error("functions can only have statements");
ShaderNodes::StatementPtr statement = std::static_pointer_cast<ShaderNodes::Statement>(node);
shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), std::move(retType));
}
return shader;
return node;
}
bool ShaderAstUnserializer::IsWriting() const
@ -676,41 +362,50 @@ namespace Nz
return false;
}
void ShaderAstUnserializer::Node(ShaderNodes::NodePtr& node)
void ShaderAstUnserializer::Node(ExpressionPtr& node)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
if (nodeTypeInt < static_cast<Int32>(ShaderNodes::NodeType::None) || nodeTypeInt > static_cast<Int32>(ShaderNodes::NodeType::Max))
if (nodeTypeInt < static_cast<Int32>(NodeType::None) || nodeTypeInt > static_cast<Int32>(NodeType::Max))
throw std::runtime_error("invalid node type");
ShaderNodes::NodeType nodeType = static_cast<ShaderNodes::NodeType>(nodeTypeInt);
#define HandleType(Type) case ShaderNodes::NodeType:: Type : node = std::make_shared<ShaderNodes:: Type>(); break
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
switch (nodeType)
{
case ShaderNodes::NodeType::None: break;
case NodeType::None: break;
HandleType(AccessMember);
HandleType(AssignOp);
HandleType(BinaryOp);
HandleType(Branch);
HandleType(Cast);
HandleType(Constant);
HandleType(ConditionalExpression);
HandleType(ConditionalStatement);
HandleType(DeclareVariable);
HandleType(Discard);
HandleType(ExpressionStatement);
HandleType(Identifier);
HandleType(IntrinsicCall);
HandleType(NoOp);
HandleType(ReturnStatement);
HandleType(Sample2D);
HandleType(SwizzleOp);
HandleType(StatementBlock);
#define NAZARA_SHADERAST_EXPRESSION(Node) case NodeType:: Node : node = std::make_unique<Node>(); break;
#include <Nazara/Shader/ShaderAstNodes.hpp>
default: throw std::runtime_error("unexpected node type");
}
if (node)
{
ShaderSerializerVisitor visitor(*this);
node->Visit(visitor);
}
}
void ShaderAstUnserializer::Node(StatementPtr& node)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
if (nodeTypeInt < static_cast<Int32>(NodeType::None) || nodeTypeInt > static_cast<Int32>(NodeType::Max))
throw std::runtime_error("invalid node type");
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
switch (nodeType)
{
case NodeType::None: break;
#define NAZARA_SHADERAST_STATEMENT(Node) case NodeType:: Node : node = std::make_unique<Node>(); break;
#include <Nazara/Shader/ShaderAstNodes.hpp>
default: throw std::runtime_error("unexpected node type");
}
#undef HandleType
if (node)
{
@ -728,7 +423,7 @@ namespace Nz
{
case 0: //< Primitive
{
ShaderNodes::BasicType exprType;
BasicType exprType;
Enum(exprType);
type = exprType;
@ -819,36 +514,8 @@ namespace Nz
m_stream >> val;
}
void ShaderAstUnserializer::Variable(ShaderNodes::VariablePtr& var)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
ShaderNodes::VariableType nodeType = static_cast<ShaderNodes:: VariableType>(nodeTypeInt);
#define HandleType(Type) case ShaderNodes::VariableType:: Type : var = std::make_shared<ShaderNodes::Type>(); break
switch (nodeType)
{
case ShaderNodes::VariableType::None: break;
HandleType(BuiltinVariable);
HandleType(InputVariable);
HandleType(LocalVariable);
HandleType(ParameterVariable);
HandleType(OutputVariable);
HandleType(UniformVariable);
}
#undef HandleType
if (var)
{
ShaderSerializerVisitor visitor(*this);
var->Visit(visitor);
}
}
ByteArray SerializeShader(const ShaderAst& shader)
ByteArray SerializeShader(StatementPtr& shader)
{
ByteArray byteArray;
ByteStream stream(&byteArray, OpenModeFlags(OpenMode_WriteOnly));
@ -859,7 +526,7 @@ namespace Nz
return byteArray;
}
ShaderAst UnserializeShader(ByteStream& stream)
StatementPtr UnserializeShader(ByteStream& stream)
{
ShaderAstUnserializer unserializer(stream);
return unserializer.Unserialize();

View File

@ -2,15 +2,10 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderVarVisitor::~ShaderVarVisitor() = default;
void ShaderVarVisitor::Visit(const ShaderNodes::VariablePtr& node)
{
node->Visit(*this);
}
AstStatementVisitor::~AstStatementVisitor() = default;
}

View File

@ -0,0 +1,15 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
}

View File

@ -5,69 +5,65 @@
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderNodes::ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression)
ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(Expression& expression)
{
Visit(expression);
expression.Visit(*this);
return m_expressionCategory;
}
void ShaderAstValueCategory::Visit(ShaderNodes::AccessMember& node)
void ShaderAstValueCategory::Visit(AccessMemberExpression& node)
{
Visit(node.structExpr);
node.structExpr->Visit(*this);
}
void ShaderAstValueCategory::Visit(ShaderNodes::AssignOp& node)
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::BinaryOp& node)
void ShaderAstValueCategory::Visit(BinaryExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Cast& node)
void ShaderAstValueCategory::Visit(CastExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::ConditionalExpression& node)
void ShaderAstValueCategory::Visit(ConditionalExpression& node)
{
Visit(node.truePath);
ShaderNodes::ExpressionCategory trueExprCategory = m_expressionCategory;
Visit(node.falsePath);
ShaderNodes::ExpressionCategory falseExprCategory = m_expressionCategory;
node.truePath->Visit(*this);
ExpressionCategory trueExprCategory = m_expressionCategory;
if (trueExprCategory == ShaderNodes::ExpressionCategory::RValue || falseExprCategory == ShaderNodes::ExpressionCategory::RValue)
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
node.falsePath->Visit(*this);
ExpressionCategory falseExprCategory = m_expressionCategory;
if (trueExprCategory == ExpressionCategory::RValue || falseExprCategory == ExpressionCategory::RValue)
m_expressionCategory = ExpressionCategory::RValue;
else
m_expressionCategory = ShaderNodes::ExpressionCategory::LValue;
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Constant& node)
void ShaderAstValueCategory::Visit(ConstantExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Identifier& node)
void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::LValue;
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::IntrinsicCall& node)
void ShaderAstValueCategory::Visit(IntrinsicExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Sample2D& node)
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::SwizzleOp& node)
{
Visit(node.expression);
node.expression->Visit(*this);
}
}

View File

@ -4,48 +4,40 @@
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <vector>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
struct AstError
{
std::string errMsg;
};
struct ShaderAstValidator::Context
struct AstValidator::Context
{
struct Local
{
std::string name;
ShaderExpressionType type;
};
const ShaderAst::Function* currentFunction;
std::vector<Local> declaredLocals;
std::vector<std::size_t> blockLocalIndex;
//const ShaderAst::Function* currentFunction;
std::optional<std::size_t> activeScopeId;
AstCache* cache;
};
bool ShaderAstValidator::Validate(std::string* error)
bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache)
{
try
{
for (std::size_t i = 0; i < m_shader.GetFunctionCount(); ++i)
{
const auto& func = m_shader.GetFunction(i);
AstCache dummy;
Context currentContext;
currentContext.currentFunction = &func;
Context currentContext;
currentContext.cache = (cache) ? cache : &dummy;
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
func.statement->Visit(*this);
}
EnterScope();
node->Visit(*this);
ExitScope();
return true;
}
@ -58,148 +50,183 @@ namespace Nz
}
}
const ShaderNodes::ExpressionPtr& ShaderAstValidator::MandatoryExpr(const ShaderNodes::ExpressionPtr& node)
{
MandatoryNode(node);
return node;
}
const ShaderNodes::NodePtr& ShaderAstValidator::MandatoryNode(const ShaderNodes::NodePtr& node)
Expression& AstValidator::MandatoryExpr(ExpressionPtr& node)
{
if (!node)
throw AstError{ "Invalid node" };
throw AstError{ "Invalid expression" };
return node;
return *node;
}
void ShaderAstValidator::TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right)
Statement& AstValidator::MandatoryStatement(StatementPtr& node)
{
return TypeMustMatch(left->GetExpressionType(), right->GetExpressionType());
if (!node)
throw AstError{ "Invalid statement" };
return *node;
}
void ShaderAstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
{
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
}
void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
{
if (left != right)
throw AstError{ "Left expression type must match right expression type" };
}
const ShaderAst::StructMember& ShaderAstValidator::CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers)
ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const auto& structs = m_shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
if (it == structs.end())
throw AstError{ "invalid structure" };
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
if (!identifier)
throw AstError{ "unknown identifier " + structName };
const ShaderAst::Struct& s = *it;
if (*memberIndex >= s.members.size())
throw AstError{ "member index out of bounds" };
if (std::holds_alternative<StructDescription>(identifier->value))
throw AstError{ "identifier is not a struct" };
const auto& member = s.members[*memberIndex];
const StructDescription& s = std::get<StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
if (memberIt == s.members.end())
throw AstError{ "unknown field " + memberIdentifier[0]};
const auto& member = *memberIt;
if (remainingMembers > 1)
{
if (!IsStructType(member.type))
throw AstError{ "member type does not match node type" };
return CheckField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
}
return CheckField(std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
else
return member;
return member.type;
}
void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node)
AstCache::Scope& AstValidator::EnterScope()
{
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
std::size_t newScopeId = m_context->cache->scopes.size();
std::optional<std::size_t> previousScope = m_context->activeScopeId;
auto& newScope = m_context->cache->scopes.emplace_back();
newScope.parentScopeIndex = previousScope;
m_context->activeScopeId = newScopeId;
return m_context->cache->scopes[newScopeId];
}
void AstValidator::ExitScope()
{
assert(m_context->activeScopeId);
auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId];
m_context->activeScopeId = previousScope.parentScopeIndex;
}
void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType)
{
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
}
void AstValidator::RegisterScope(Node& node)
{
if (m_context->activeScopeId)
m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId;
}
void AstValidator::Visit(AccessMemberExpression& node)
{
RegisterScope(node);
const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
if (!IsStructType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
const auto& member = CheckField(structName, node.memberIndices.data(), node.memberIndices.size());
if (member.type != node.exprType)
throw AstError{ "member type does not match node type" };
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
}
void ShaderAstValidator::Visit(ShaderNodes::AssignOp& node)
void AstValidator::Visit(AssignExpression& node)
{
MandatoryNode(node.left);
MandatoryNode(node.right);
RegisterScope(node);
MandatoryExpr(node.left);
MandatoryExpr(node.right);
TypeMustMatch(node.left, node.right);
if (GetExpressionCategory(node.left) != ShaderNodes::ExpressionCategory::LValue)
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError { "Assignation is only possible with a l-value" };
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::BinaryOp& node)
void AstValidator::Visit(BinaryExpression& node)
{
MandatoryNode(node.left);
MandatoryNode(node.right);
RegisterScope(node);
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
// Register expression type
AstRecursiveVisitor::Visit(node);
const ShaderExpressionType& leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
if (!IsBasicType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
const ShaderExpressionType& rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
if (!IsBasicType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
BasicType leftType = std::get<BasicType>(leftExprType);
BasicType rightType = std::get<BasicType>(rightExprType);
switch (node.op)
{
case ShaderNodes::BinaryType::CompGe:
case ShaderNodes::BinaryType::CompGt:
case ShaderNodes::BinaryType::CompLe:
case ShaderNodes::BinaryType::CompLt:
if (leftType == ShaderNodes::BasicType::Boolean)
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == BasicType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
case ShaderNodes::BinaryType::Add:
case ShaderNodes::BinaryType::CompEq:
case ShaderNodes::BinaryType::CompNe:
case ShaderNodes::BinaryType::Subtract:
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case ShaderNodes::BinaryType::Multiply:
case ShaderNodes::BinaryType::Divide:
case BinaryType::Multiply:
case BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Int1:
case BasicType::Float1:
case BasicType::Int1:
{
if (ShaderNodes::Node::GetComponentType(rightType) != leftType)
if (GetComponentType(rightType) != leftType)
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
{
if (leftType != rightType && rightType != ShaderNodes::Node::GetComponentType(leftType))
if (leftType != rightType && rightType != GetComponentType(leftType))
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case ShaderNodes::BasicType::Mat4x4:
case BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case BasicType::Float1:
case BasicType::Float4:
case BasicType::Mat4x4:
break;
default:
@ -211,120 +238,86 @@ namespace Nz
default:
TypeMustMatch(node.left, node.right);
break;
}
}
}
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Branch& node)
void AstValidator::Visit(CastExpression& node)
{
for (const auto& condStatement : node.condStatements)
{
const ShaderExpressionType& condType = MandatoryExpr(condStatement.condition)->GetExpressionType();
if (!IsBasicType(condType) || std::get<ShaderNodes::BasicType>(condType) != ShaderNodes::BasicType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
RegisterScope(node);
MandatoryNode(condStatement.statement);
}
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Cast& node)
{
unsigned int componentCount = 0;
unsigned int requiredComponents = node.GetComponentCount(node.exprType);
for (const auto& exprPtr : node.expressions)
unsigned int requiredComponents = GetComponentCount(node.targetType);
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
if (!IsBasicType(exprType))
throw AstError{ "incompatible type" };
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
componentCount += GetComponentCount(std::get<BasicType>(exprType));
}
if (componentCount != requiredComponents)
throw AstError{ "component count doesn't match required component count" };
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node)
void AstValidator::Visit(ConditionalExpression& node)
{
MandatoryNode(node.truePath);
MandatoryNode(node.falsePath);
MandatoryExpr(node.truePath);
MandatoryExpr(node.falsePath);
if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
throw AstError{ "condition not found" };
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
// throw AstError{ "condition not found" };
}
void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node)
void AstValidator::Visit(ConstantExpression& node)
{
MandatoryNode(node.statement);
if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
throw AstError{ "condition not found" };
RegisterScope(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/)
{
}
void ShaderAstValidator::Visit(ShaderNodes::DeclareVariable& node)
void AstValidator::Visit(IdentifierExpression& node)
{
assert(m_context);
if (node.variable->GetType() != ShaderNodes::VariableType::LocalVariable)
throw AstError{ "Only local variables can be declared in a statement" };
if (!m_context->activeScopeId)
throw AstError{ "no scope" };
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
RegisterScope(node);
auto& local = m_context->declaredLocals.emplace_back();
local.name = localVar.name;
local.type = localVar.type;
ShaderAstRecursiveVisitor::Visit(node);
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier);
if (!identifier)
throw AstError{ "Unknown variable " + node.identifier };
}
void ShaderAstValidator::Visit(ShaderNodes::ExpressionStatement& node)
void AstValidator::Visit(IntrinsicExpression& node)
{
MandatoryNode(node.expression);
RegisterScope(node);
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Identifier& node)
{
assert(m_context);
if (!node.var)
throw AstError{ "Invalid variable" };
Visit(node.var);
}
void ShaderAstValidator::Visit(ShaderNodes::IntrinsicCall& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderNodes::IntrinsicType::DotProduct:
case IntrinsicType::CrossProduct:
case IntrinsicType::DotProduct:
{
if (node.parameters.size() != 2)
throw AstError { "Expected 2 parameters" };
for (auto& param : node.parameters)
MandatoryNode(param);
MandatoryExpr(param);
ShaderExpressionType type = node.parameters.front()->GetExpressionType();
ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
for (std::size_t i = 1; i < node.parameters.size(); ++i)
{
if (type != node.parameters[i]->GetExpressionType())
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache)
throw AstError{ "All type must match" };
}
@ -334,180 +327,176 @@ namespace Nz
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::CrossProduct:
case IntrinsicType::CrossProduct:
{
if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 })
if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache)
throw AstError{ "CrossProduct only works with Float3 expressions" };
break;
}
case ShaderNodes::IntrinsicType::DotProduct:
case IntrinsicType::DotProduct:
break;
}
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::ReturnStatement& node)
void AstValidator::Visit(SwizzleExpression& node)
{
if (m_context->currentFunction->returnType != ShaderExpressionType(ShaderNodes::BasicType::Void))
{
if (MandatoryExpr(node.returnExpr)->GetExpressionType() != m_context->currentFunction->returnType)
throw AstError{ "Return type doesn't match function return type" };
}
else
{
if (node.returnExpr)
throw AstError{ "Unexpected expression for return (function doesn't return)" };
}
RegisterScope(node);
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node)
{
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D })
throw AstError{ "Sampler must be a Sampler2D" };
if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 })
throw AstError{ "Coordinates must be a Float2" };
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::StatementBlock& node)
{
assert(m_context);
m_context->blockLocalIndex.push_back(m_context->declaredLocals.size());
for (const auto& statement : node.statements)
MandatoryNode(statement);
assert(m_context->declaredLocals.size() >= m_context->blockLocalIndex.back());
m_context->declaredLocals.resize(m_context->blockLocalIndex.back());
m_context->blockLocalIndex.pop_back();
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::SwizzleOp& node)
{
if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" };
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsBasicType(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<ShaderNodes::BasicType>(exprType))
switch (std::get<BasicType>(exprType))
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
break;
default:
throw AstError{ "Cannot swizzle this type" };
}
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::BuiltinVariable& var)
void AstValidator::Visit(BranchStatement& node)
{
switch (var.entry)
RegisterScope(node);
for (auto& condStatement : node.condStatements)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
if (!IsBasicType(var.type) ||
std::get<ShaderNodes::BasicType>(var.type) != ShaderNodes::BasicType::Float4)
throw AstError{ "Builtin is not of the expected type" };
const ShaderExpressionType& condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
if (!IsBasicType(condType) || std::get<BasicType>(condType) != BasicType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
break;
default:
break;
}
}
void ShaderAstValidator::Visit(ShaderNodes::InputVariable& var)
{
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
{
const auto& input = m_shader.GetInput(i);
if (input.name == var.name)
{
TypeMustMatch(input.type, var.type);
return;
}
MandatoryStatement(condStatement.statement);
}
throw AstError{ "Input not found" };
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::LocalVariable& var)
void AstValidator::Visit(ConditionalStatement& node)
{
const auto& vars = m_context->declaredLocals;
MandatoryStatement(node.statement);
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; });
if (it == vars.end())
throw AstError{ "Local variable not found in this block" };
RegisterScope(node);
TypeMustMatch(it->type, var.type);
AstRecursiveVisitor::Visit(node);
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
// throw AstError{ "condition not found" };
}
void ShaderAstValidator::Visit(ShaderNodes::OutputVariable& var)
void AstValidator::Visit(DeclareFunctionStatement& node)
{
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
auto& scope = EnterScope();
RegisterScope(node);
for (auto& parameter : node.parameters)
{
const auto& input = m_shader.GetOutput(i);
if (input.name == var.name)
{
TypeMustMatch(input.type, var.type);
return;
}
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } };
}
throw AstError{ "Output not found" };
for (auto& statement : node.statements)
MandatoryStatement(statement).Visit(*this);
ExitScope();
}
void ShaderAstValidator::Visit(ShaderNodes::ParameterVariable& var)
void AstValidator::Visit(DeclareStructStatement& node)
{
assert(m_context->currentFunction);
assert(m_context);
const auto& parameters = m_context->currentFunction->parameters;
if (!m_context->activeScopeId)
throw AstError{ "cannot declare variable without scope" };
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; });
if (it == parameters.end())
throw AstError{ "Parameter not found in function" };
RegisterScope(node);
TypeMustMatch(it->type, var.type);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ node.description.name, node.description };
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::UniformVariable& var)
void AstValidator::Visit(DeclareVariableStatement& node)
{
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
assert(m_context);
if (!m_context->activeScopeId)
throw AstError{ "cannot declare variable without scope" };
RegisterScope(node);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(ExpressionStatement& node)
{
RegisterScope(node);
MandatoryExpr(node.expression);
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(MultiStatement& node)
{
assert(m_context);
EnterScope();
RegisterScope(node);
for (auto& statement : node.statements)
MandatoryStatement(statement);
ExitScope();
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(ReturnStatement& node)
{
RegisterScope(node);
/*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void))
{
const auto& uniform = m_shader.GetUniform(i);
if (uniform.name == var.name)
{
TypeMustMatch(uniform.type, var.type);
return;
}
if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType)
throw AstError{ "Return type doesn't match function return type" };
}
else
{
if (node.returnExpr)
throw AstError{ "Unexpected expression for return (function doesn't return)" };
}*/
throw AstError{ "Uniform not found" };
AstRecursiveVisitor::Visit(node);
}
bool ValidateShader(const ShaderAst& shader, std::string* error)
bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache)
{
ShaderAstValidator validator(shader);
return validator.Validate(error);
AstValidator validator;
return validator.Validate(node, error, cache);
}
}

View File

@ -1,100 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void ShaderAstVisitorExcept::Visit(ShaderNodes::AccessMember& /*node*/)
{
throw std::runtime_error("unhandled AccessMember node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::AssignOp& /*node*/)
{
throw std::runtime_error("unhandled AssignOp node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::BinaryOp& /*node*/)
{
throw std::runtime_error("unhandled AccessMember node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Branch& /*node*/)
{
throw std::runtime_error("unhandled Branch node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Cast& /*node*/)
{
throw std::runtime_error("unhandled Cast node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/)
{
throw std::runtime_error("unhandled ConditionalExpression node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/)
{
throw std::runtime_error("unhandled ConditionalStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/)
{
throw std::runtime_error("unhandled Constant node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::DeclareVariable& /*node*/)
{
throw std::runtime_error("unhandled DeclareVariable node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Discard& /*node*/)
{
throw std::runtime_error("unhandled Discard node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ExpressionStatement& /*node*/)
{
throw std::runtime_error("unhandled ExpressionStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Identifier& /*node*/)
{
throw std::runtime_error("unhandled Identifier node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::IntrinsicCall& /*node*/)
{
throw std::runtime_error("unhandled IntrinsicCall node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::NoOp& node)
{
throw std::runtime_error("unhandled NoOp node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ReturnStatement& node)
{
throw std::runtime_error("unhandled ReturnStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& /*node*/)
{
throw std::runtime_error("unhandled Sample2D node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::StatementBlock& /*node*/)
{
throw std::runtime_error("unhandled StatementBlock node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::SwizzleOp& /*node*/)
{
throw std::runtime_error("unhandled SwizzleOp node");
}
}

View File

@ -42,6 +42,7 @@ namespace Nz::ShaderLang
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "true", TokenType::BoolTrue }
};
@ -143,7 +144,7 @@ namespace Nz::ShaderLang
while (next != -1);
}
else
tokenType == TokenType::Divide;
tokenType = TokenType::Divide;
break;
}
@ -191,9 +192,11 @@ namespace Nz::ShaderLang
std::string valueStr(str.substr(start, currentPos - start + 1));
const char* ptr = valueStr.c_str();
char* end;
double value = std::strtod(valueStr.c_str(), &end);
if (end != &str[currentPos + 1])
double value = std::strtod(ptr, &end);
if (end != &ptr[valueStr.size()])
throw BadNumber{};
token.data = value;
@ -218,6 +221,7 @@ namespace Nz::ShaderLang
break;
}
case '=': tokenType = TokenType::Assign; break;
case '+': tokenType = TokenType::Plus; break;
case '*': tokenType = TokenType::Multiply; break;
case ':': tokenType = TokenType::Colon; break;

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <cassert>
#include <Nazara/Shader/Debug.hpp>
@ -10,36 +11,38 @@ namespace Nz::ShaderLang
{
namespace
{
std::unordered_map<std::string, ShaderNodes::BasicType> identifierToBasicType = {
{ "bool", ShaderNodes::BasicType::Boolean },
std::unordered_map<std::string, ShaderAst::BasicType> identifierToBasicType = {
{ "bool", ShaderAst::BasicType::Boolean },
{ "i32", ShaderNodes::BasicType::Int1 },
{ "vec2i32", ShaderNodes::BasicType::Int2 },
{ "vec3i32", ShaderNodes::BasicType::Int3 },
{ "vec4i32", ShaderNodes::BasicType::Int4 },
{ "i32", ShaderAst::BasicType::Int1 },
{ "vec2i32", ShaderAst::BasicType::Int2 },
{ "vec3i32", ShaderAst::BasicType::Int3 },
{ "vec4i32", ShaderAst::BasicType::Int4 },
{ "f32", ShaderNodes::BasicType::Float1 },
{ "vec2f32", ShaderNodes::BasicType::Float2 },
{ "vec3f32", ShaderNodes::BasicType::Float3 },
{ "vec4f32", ShaderNodes::BasicType::Float4 },
{ "f32", ShaderAst::BasicType::Float1 },
{ "vec2f32", ShaderAst::BasicType::Float2 },
{ "vec3f32", ShaderAst::BasicType::Float3 },
{ "vec4f32", ShaderAst::BasicType::Float4 },
{ "mat4x4f32", ShaderNodes::BasicType::Mat4x4 },
{ "sampler2D", ShaderNodes::BasicType::Sampler2D },
{ "void", ShaderNodes::BasicType::Void },
{ "mat4x4f32", ShaderAst::BasicType::Mat4x4 },
{ "sampler2D", ShaderAst::BasicType::Sampler2D },
{ "void", ShaderAst::BasicType::Void },
{ "u32", ShaderNodes::BasicType::UInt1 },
{ "vec2u32", ShaderNodes::BasicType::UInt3 },
{ "vec3u32", ShaderNodes::BasicType::UInt3 },
{ "vec4u32", ShaderNodes::BasicType::UInt4 },
{ "u32", ShaderAst::BasicType::UInt1 },
{ "vec2u32", ShaderAst::BasicType::UInt3 },
{ "vec3u32", ShaderAst::BasicType::UInt3 },
{ "vec4u32", ShaderAst::BasicType::UInt4 },
};
}
ShaderAst Parser::Parse(const std::vector<Token>& tokens)
ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens)
{
Context context;
context.tokenCount = tokens.size();
context.tokens = tokens.data();
context.root = std::make_unique<ShaderAst::MultiStatement>();
m_context = &context;
m_context->tokenIndex = -1;
@ -51,7 +54,7 @@ namespace Nz::ShaderLang
switch (nextToken.type)
{
case TokenType::FunctionDeclaration:
ParseFunctionDeclaration();
context.root->statements.push_back(ParseFunctionDeclaration());
break;
case TokenType::EndOfStream:
@ -63,7 +66,7 @@ namespace Nz::ShaderLang
}
}
return std::move(context.result);
return std::move(context.root);
}
const Token& Parser::Advance()
@ -92,12 +95,12 @@ namespace Nz::ShaderLang
return m_context->tokens[m_context->tokenIndex + 1];
}
ShaderNodes::StatementPtr Parser::ParseFunctionBody()
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
{
return ParseStatementList();
}
void Parser::ParseFunctionDeclaration()
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration()
{
ExpectNext(TokenType::FunctionDeclaration);
@ -105,7 +108,7 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::OpenParenthesis);
std::vector<ShaderAst::FunctionParameter> parameters;
std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters;
bool firstParameter = true;
for (;;)
@ -126,7 +129,7 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::ClosingParenthesis);
ShaderExpressionType returnType = ShaderNodes::BasicType::Void;
ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void;
if (PeekNext().type == TokenType::FunctionReturn)
{
Advance(); //< Consume ->
@ -136,42 +139,46 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::OpenCurlyBracket);
ShaderNodes::StatementPtr functionBody = ParseFunctionBody();
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
ExpectNext(TokenType::ClosingCurlyBracket);
m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType);
return ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
}
ShaderAst::FunctionParameter Parser::ParseFunctionParameter()
ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter()
{
std::string parameterName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
ShaderExpressionType parameterType = ParseIdentifierAsType();
ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType();
return { parameterName, parameterType };
}
ShaderNodes::StatementPtr Parser::ParseReturnStatement()
ShaderAst::StatementPtr Parser::ParseReturnStatement()
{
ExpectNext(TokenType::Return);
ShaderNodes::ExpressionPtr expr;
ShaderAst::ExpressionPtr expr;
if (PeekNext().type != TokenType::Semicolon)
expr = ParseExpression();
return ShaderNodes::ReturnStatement::Build(std::move(expr));
return ShaderBuilder::Return(std::move(expr));
}
ShaderNodes::StatementPtr Parser::ParseStatement()
ShaderAst::StatementPtr Parser::ParseStatement()
{
const Token& token = PeekNext();
ShaderNodes::StatementPtr statement;
ShaderAst::StatementPtr statement;
switch (token.type)
{
case TokenType::Let:
statement = ParseVariableDeclaration();
break;
case TokenType::Return:
statement = ParseReturnStatement();
break;
@ -185,18 +192,38 @@ namespace Nz::ShaderLang
return statement;
}
ShaderNodes::StatementPtr Parser::ParseStatementList()
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{
std::vector<ShaderNodes::StatementPtr> statements;
std::vector<ShaderAst::StatementPtr> statements;
while (PeekNext().type != TokenType::ClosingCurlyBracket)
{
statements.push_back(ParseStatement());
}
return ShaderNodes::StatementBlock::Build(std::move(statements));
return statements;
}
ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs)
ShaderAst::StatementPtr Parser::ParseVariableDeclaration()
{
ExpectNext(TokenType::Let);
std::string variableName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType();
ShaderAst::ExpressionPtr expression;
if (PeekNext().type == TokenType::Assign)
{
Advance();
expression = ParseExpression();
}
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
{
for (;;)
{
@ -207,7 +234,7 @@ namespace Nz::ShaderLang
return lhs;
Advance();
ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression();
ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression();
const Token& nextOp = PeekNext();
@ -215,57 +242,58 @@ namespace Nz::ShaderLang
if (tokenPrecedence < nextTokenPrecedence)
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
ShaderNodes::BinaryType binaryType;
ShaderAst::BinaryType binaryType;
{
switch (currentOp.type)
{
case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break;
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
default: throw UnexpectedToken{};
}
}
lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs));
lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
}
}
ShaderNodes::ExpressionPtr Parser::ParseExpression()
ShaderAst::ExpressionPtr Parser::ParseExpression()
{
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderNodes::ExpressionPtr Parser::ParseIdentifier()
ShaderAst::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get<std::string>(identifier.data), ShaderNodes::BasicType::Float3));
return ShaderBuilder::Identifier(std::get<std::string>(identifier.data));
}
ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression()
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression()
{
const Token& integer = ExpectNext(TokenType::IntegerValue);
return ShaderNodes::Constant::Build(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
return ShaderBuilder::Constant(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
}
ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression()
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
{
ExpectNext(TokenType::OpenParenthesis);
ShaderNodes::ExpressionPtr expression = ParseExpression();
ShaderAst::ExpressionPtr expression = ParseExpression();
ExpectNext(TokenType::ClosingParenthesis);
return expression;
}
ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression()
ShaderAst::ExpressionPtr Parser::ParsePrimaryExpression()
{
const Token& token = PeekNext();
switch (token.type)
{
case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false);
case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true);
case TokenType::BoolFalse: return ShaderBuilder::Constant(false);
case TokenType::BoolTrue: return ShaderBuilder::Constant(true);
case TokenType::FloatingPointValue: return ShaderBuilder::Constant(float(std::get<double>(Advance().data))); //< FIXME
case TokenType::Identifier: return ParseIdentifier();
case TokenType::IntegerValue: return ParseIntegerExpression();
case TokenType::OpenParenthesis: return ParseParenthesisExpression();
@ -286,7 +314,7 @@ namespace Nz::ShaderLang
return identifier;
}
ShaderExpressionType Parser::ParseIdentifierAsType()
ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType()
{
const Token& identifier = ExpectNext(TokenType::Identifier);

View File

@ -4,265 +4,29 @@
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderNodes
namespace Nz::ShaderAst
{
Node::~Node() = default;
void ExpressionStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
#define NAZARA_SHADERAST_NODE(Node) NodeType Node::GetType() const \
{ \
return NodeType:: Node; \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
#define NAZARA_SHADERAST_EXPRESSION(Node) void Node::Visit(AstExpressionVisitor& visitor) \
{\
visitor.Visit(*this); \
}
void ConditionalStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
#define NAZARA_SHADERAST_STATEMENT(Node) void Node::Visit(AstStatementVisitor& visitor) \
{\
visitor.Visit(*this); \
}
void StatementBlock::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void DeclareVariable::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void Discard::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Identifier::GetExpressionType() const
{
assert(var);
return var->type;
}
void Identifier::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType AccessMember::GetExpressionType() const
{
return exprType;
}
void AccessMember::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void NoOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void ReturnStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType AssignOp::GetExpressionType() const
{
return left->GetExpressionType();
}
void AssignOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType BinaryOp::GetExpressionType() const
{
std::optional<ShaderExpressionType> exprType;
switch (op)
{
case BinaryType::Add:
case BinaryType::Subtract:
exprType = left->GetExpressionType();
break;
case BinaryType::Divide:
case BinaryType::Multiply:
{
const ShaderExpressionType& leftExprType = left->GetExpressionType();
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = right->GetExpressionType();
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
{
case BasicType::Boolean:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
exprType = leftExprType;
break;
case BasicType::Float1:
case BasicType::Int1:
case BasicType::Mat4x4:
case BasicType::UInt1:
exprType = rightExprType;
break;
case BasicType::Sampler2D:
case BasicType::Void:
break;
}
break;
}
case BinaryType::CompEq:
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
exprType = BasicType::Boolean;
break;
}
NazaraAssert(exprType.has_value(), "Unhandled builtin");
return *exprType;
}
void BinaryOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void Branch::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Constant::GetExpressionType() const
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return ShaderNodes::BasicType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return ShaderNodes::BasicType::Float1;
else if constexpr (std::is_same_v<T, Int32>)
return ShaderNodes::BasicType::Int1;
else if constexpr (std::is_same_v<T, UInt32>)
return ShaderNodes::BasicType::Int1;
else if constexpr (std::is_same_v<T, Vector2f>)
return ShaderNodes::BasicType::Float2;
else if constexpr (std::is_same_v<T, Vector3f>)
return ShaderNodes::BasicType::Float3;
else if constexpr (std::is_same_v<T, Vector4f>)
return ShaderNodes::BasicType::Float4;
else if constexpr (std::is_same_v<T, Vector2i32>)
return ShaderNodes::BasicType::Int2;
else if constexpr (std::is_same_v<T, Vector3i32>)
return ShaderNodes::BasicType::Int3;
else if constexpr (std::is_same_v<T, Vector4i32>)
return ShaderNodes::BasicType::Int4;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, value);
}
void Constant::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Cast::GetExpressionType() const
{
return exprType;
}
void Cast::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType ConditionalExpression::GetExpressionType() const
{
assert(truePath->GetExpressionType() == falsePath->GetExpressionType());
return truePath->GetExpressionType();
}
void ConditionalExpression::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType SwizzleOp::GetExpressionType() const
{
const ShaderExpressionType& exprType = expression->GetExpressionType();
assert(IsBasicType(exprType));
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
}
void SwizzleOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Sample2D::GetExpressionType() const
{
return BasicType::Float4;
}
void Sample2D::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType IntrinsicCall::GetExpressionType() const
{
switch (intrinsic)
{
case IntrinsicType::CrossProduct:
return parameters.front()->GetExpressionType();
case IntrinsicType::DotProduct:
return BasicType::Float1;
}
NazaraAssert(false, "Unhandled builtin");
return BasicType::Void;
}
void IntrinsicCall::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
}

View File

@ -1,40 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void ShaderVarVisitorExcept::Visit(ShaderNodes::BuiltinVariable& /*var*/)
{
throw std::runtime_error("unhandled BuiltinVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::InputVariable& /*var*/)
{
throw std::runtime_error("unhandled InputVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::LocalVariable& /*var*/)
{
throw std::runtime_error("unhandled LocalVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::OutputVariable& /*var*/)
{
throw std::runtime_error("unhandled OutputVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::ParameterVariable& /*var*/)
{
throw std::runtime_error("unhandled ParameterVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::UniformVariable& /*var*/)
{
throw std::runtime_error("unhandled UniformVariable");
}
}

View File

@ -1,77 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderNodes
{
ShaderNodes::Variable::~Variable() = default;
VariableType BuiltinVariable::GetType() const
{
return VariableType::BuiltinVariable;
}
void BuiltinVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType InputVariable::GetType() const
{
return VariableType::InputVariable;
}
void InputVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType LocalVariable::GetType() const
{
return VariableType::LocalVariable;
}
void LocalVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType OutputVariable::GetType() const
{
return VariableType::OutputVariable;
}
void OutputVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType ParameterVariable::GetType() const
{
return VariableType::ParameterVariable;
}
void ParameterVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType UniformVariable::GetType() const
{
return VariableType::UniformVariable;
}
void UniformVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
}

View File

@ -4,6 +4,7 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Shader/SpirvExpressionStore.hpp>
@ -12,21 +13,21 @@
namespace Nz
{
UInt32 SpirvAstVisitor::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr)
{
Visit(expr);
expr->Visit(*this);
assert(m_resultIds.size() == 1);
return PopResultId();
}
void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node)
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node)
{
SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock);
PushResultId(accessMemberVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderNodes::AssignOp& node)
void SpirvAstVisitor::Visit(ShaderAst::AssignExpression& node)
{
UInt32 resultId = EvaluateExpression(node.right);
@ -36,20 +37,20 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node);
assert(IsBasicType(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left);
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right);
assert(IsBasicType(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType);
ShaderAst::BasicType leftType = std::get<ShaderAst::BasicType>(leftExprType);
ShaderAst::BasicType rightType = std::get<ShaderAst::BasicType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
@ -62,308 +63,308 @@ namespace Nz
{
switch (node.op)
{
case ShaderNodes::BinaryType::Add:
case ShaderAst::BinaryType::Add:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpIAdd;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::Subtract:
case ShaderAst::BinaryType::Subtract:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpISub;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::Divide:
case ShaderAst::BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSDiv;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpUDiv;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompEq:
case ShaderAst::BinaryType::CompEq:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
case ShaderAst::BasicType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpIEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompGe:
case ShaderAst::BinaryType::CompGe:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdGreaterThan;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSGreaterThan;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpUGreaterThan;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompGt:
case ShaderAst::BinaryType::CompGt:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdGreaterThanEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSGreaterThanEqual;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpUGreaterThanEqual;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompLe:
case ShaderAst::BinaryType::CompLe:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdLessThanEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSLessThanEqual;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpULessThanEqual;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompLt:
case ShaderAst::BinaryType::CompLt:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdLessThan;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSLessThan;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpULessThan;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompNe:
case ShaderAst::BinaryType::CompNe:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
case ShaderAst::BasicType::Boolean:
return SpirvOp::OpLogicalNotEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdNotEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpINotEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::Multiply:
case ShaderAst::BinaryType::Multiply:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
@ -374,21 +375,21 @@ namespace Nz
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
@ -398,23 +399,23 @@ namespace Nz
break;
}
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpIMul;
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
@ -442,7 +443,7 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::Branch& node)
void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node)
{
assert(!node.condStatements.empty());
auto& firstCond = node.condStatements.front();
@ -450,7 +451,8 @@ namespace Nz
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
SpirvBlock previousContentBlock(m_writer);
m_currentBlock = &previousContentBlock;
Visit(firstCond.statement);
firstCond.statement->Visit(*this);
SpirvBlock mergeBlock(m_writer);
m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
@ -458,7 +460,7 @@ namespace Nz
std::optional<std::size_t> nextBlock;
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
{
const auto& statement = node.condStatements[statementIndex];
auto& statement = node.condStatements[statementIndex];
SpirvBlock contentBlock(m_writer);
@ -469,7 +471,8 @@ namespace Nz
previousContentBlock = std::move(contentBlock);
m_currentBlock = &previousContentBlock;
Visit(statement.statement);
statement.statement->Visit(*this);
}
if (node.elseStatement)
@ -477,7 +480,7 @@ namespace Nz
SpirvBlock elseBlock(m_writer);
m_currentBlock = &elseBlock;
Visit(node.elseStatement);
node.elseStatement->Visit(*this);
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice
@ -496,16 +499,16 @@ namespace Nz
m_currentBlock = &m_blocks.back();
}
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
const ShaderAst::ShaderExpressionType& targetExprType = node.targetType;
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (const auto& exprPtr : node.expressions)
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
@ -527,21 +530,21 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node)
void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
Visit(node.truePath);
node.truePath->Visit(*this);
else
Visit(node.falsePath);
node.falsePath->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node)
void SpirvAstVisitor::Visit(ShaderAst::ConditionalStatement& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
Visit(node.statement);
node.statement->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderNodes::Constant& node)
void SpirvAstVisitor::Visit(ShaderAst::ConstantExpression& node)
{
std::visit([&] (const auto& value)
{
@ -549,46 +552,42 @@ namespace Nz
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderNodes::DeclareVariable& node)
void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node)
{
if (node.expression)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression));
}
if (node.initialExpression)
m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression));
}
void SpirvAstVisitor::Visit(ShaderNodes::Discard& /*node*/)
void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/)
{
m_currentBlock->Append(SpirvOp::OpKill);
}
void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node)
void SpirvAstVisitor::Visit(ShaderAst::ExpressionStatement& node)
{
Visit(node.expression);
node.expression->Visit(*this);
PopResultId();
}
void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node)
void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderNodes::IntrinsicCall& node)
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
const ShaderAst::ShaderExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsBasicType(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType);
UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType));
UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType));
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
@ -600,18 +599,18 @@ namespace Nz
break;
}
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
}
}
void SpirvAstVisitor::Visit(ShaderNodes::NoOp& /*node*/)
void SpirvAstVisitor::Visit(ShaderAst::NoOpStatement& /*node*/)
{
// nothing to do
}
void SpirvAstVisitor::Visit(ShaderNodes::ReturnStatement& node)
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
{
if (node.returnExpr)
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
@ -619,30 +618,18 @@ namespace Nz
m_currentBlock->Append(SpirvOp::OpReturn);
}
void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node)
{
UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4);
UInt32 samplerId = EvaluateExpression(node.sampler);
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::StatementBlock& node)
void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node)
{
for (auto& statement : node.statements)
Visit(statement);
statement->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
const ShaderAst::ShaderExpressionType& targetExprType = ShaderAst::GetExpressionType(node);
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();
@ -666,7 +653,7 @@ namespace Nz
// Extract a single component from the vector
assert(node.componentCount == 1);
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) );
}
PushResultId(resultId);

View File

@ -3,7 +3,6 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Utility/FieldOffsets.hpp>
#include <tsl/ordered_map.h>
@ -536,7 +535,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector2f>) ? ShaderNodes::BasicType::Float2 : ShaderNodes::BasicType::Int2),
BuildType((std::is_same_v<T, Vector2f>) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2),
{
BuildConstant(arg.x),
BuildConstant(arg.y)
@ -546,7 +545,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector3f>) ? ShaderNodes::BasicType::Float3 : ShaderNodes::BasicType::Int3),
BuildType((std::is_same_v<T, Vector3f>) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@ -557,7 +556,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector4f>) ? ShaderNodes::BasicType::Float4 : ShaderNodes::BasicType::Int4),
BuildType((std::is_same_v<T, Vector4f>) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@ -571,7 +570,7 @@ namespace Nz
}, value));
}
auto SpirvConstantCache::BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(SpirvConstantCache::Pointer{
SpirvConstantCache::BuildType(type),
@ -579,55 +578,55 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(SpirvConstantCache::Pointer{
SpirvConstantCache::BuildType(shader, type),
SpirvConstantCache::BuildType(type),
storageClass
});
}
auto SpirvConstantCache::BuildType(const ShaderNodes::BasicType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
switch (type)
{
case ShaderNodes::BasicType::Boolean:
case ShaderAst::BasicType::Boolean:
return Bool{};
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
return Float{ 32 };
case ShaderNodes::BasicType::Int1:
case ShaderAst::BasicType::Int1:
return Integer{ 32, true };
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
{
auto vecType = BuildType(ShaderNodes::Node::GetComponentType(type));
UInt32 componentCount = ShaderNodes::Node::GetComponentCount(type);
auto vecType = BuildType(ShaderAst::GetComponentType(type));
UInt32 componentCount = ShaderAst::GetComponentCount(type);
return Vector{ vecType, componentCount };
}
case ShaderNodes::BasicType::Mat4x4:
return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u };
case ShaderAst::BasicType::Mat4x4:
return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u };
case ShaderNodes::BasicType::UInt1:
case ShaderAst::BasicType::UInt1:
return Integer{ 32, false };
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Void:
return Void{};
case ShaderNodes::BasicType::Sampler2D:
case ShaderAst::BasicType::Sampler2D:
{
auto imageType = Image{
{}, //< qualifier
@ -635,7 +634,7 @@ namespace Nz
{}, //< sampled
SpirvDim::Dim2D, //< dim
SpirvImageFormat::Unknown, //< format
BuildType(ShaderNodes::BasicType::Float1), //< sampledType
BuildType(ShaderAst::BasicType::Float1), //< sampledType
false, //< arrayed,
false //< multisampled
};
@ -648,16 +647,16 @@ namespace Nz
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst& shader, const ShaderExpressionType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr
{
return std::visit([&](auto&& arg) -> TypePtr
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
if constexpr (std::is_same_v<T, ShaderAst::BasicType>)
return BuildType(arg);
else if constexpr (std::is_same_v<T, std::string>)
{
// Register struct members type
/*// Register struct members type
const auto& structs = shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
@ -675,7 +674,8 @@ namespace Nz
sMembers.type = BuildType(shader, member.type);
}
return std::make_shared<Type>(std::move(sType));
return std::make_shared<Type>(std::move(sType));*/
return nullptr;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");

View File

@ -16,7 +16,7 @@ namespace Nz
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
}
UInt32 SpirvExpressionLoad::Evaluate(ShaderNodes::Expression& node)
UInt32 SpirvExpressionLoad::Evaluate(ShaderAst::Expression& node)
{
node.Visit(*this);
@ -41,7 +41,7 @@ namespace Nz
}, m_value);
}
void SpirvExpressionLoad::Visit(ShaderNodes::AccessMember& node)
/*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node)
{
Visit(node.structExpr);
@ -49,6 +49,8 @@ namespace Nz
{
[&](const Pointer& pointer)
{
ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr);
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 typeId = m_writer.GetTypeId(node.exprType);
@ -87,40 +89,15 @@ namespace Nz
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
}*/
void SpirvExpressionLoad::Visit(ShaderNodes::Identifier& node)
void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node)
{
Visit(node.var);
}
void SpirvExpressionLoad::Visit(ShaderNodes::InputVariable& var)
{
auto inputVar = m_writer.GetInputVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
if (node.identifier == "d")
m_value = Value{ m_writer.ReadLocalVariable(node.identifier) };
else
m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId };
}
m_value = Value{ m_writer.ReadParameterVariable(node.identifier) };
void SpirvExpressionLoad::Visit(ShaderNodes::LocalVariable& var)
{
m_value = Value{ m_writer.ReadLocalVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& var)
{
m_value = Value{ m_writer.ReadParameterVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var)
{
auto uniformVar = m_writer.GetUniformVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
else
m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId };
//Visit(node.var);
}
}

View File

@ -15,9 +15,9 @@ namespace Nz
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
}
void SpirvExpressionStore::Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId)
void SpirvExpressionStore::Store(ShaderAst::ExpressionPtr& node, UInt32 resultId)
{
Visit(node);
node->Visit(*this);
std::visit(overloaded
{
@ -36,7 +36,7 @@ namespace Nz
}, m_value);
}
void SpirvExpressionStore::Visit(ShaderNodes::AccessMember& node)
/*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node)
{
Visit(node.structExpr);
@ -70,34 +70,15 @@ namespace Nz
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
}*/
void SpirvExpressionStore::Visit(ShaderNodes::Identifier& node)
void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node)
{
Visit(node.var);
m_value = LocalVar{ node.identifier };
}
void SpirvExpressionStore::Visit(ShaderNodes::SwizzleOp& node)
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
{
throw std::runtime_error("not yet implemented");
}
void SpirvExpressionStore::Visit(ShaderNodes::BuiltinVariable& var)
{
const auto& outputVar = m_writer.GetBuiltinVariable(var.entry);
m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId };
}
void SpirvExpressionStore::Visit(ShaderNodes::LocalVariable& var)
{
m_value = LocalVar{ var.name };
}
void SpirvExpressionStore::Visit(ShaderNodes::OutputVariable& var)
{
const auto& outputVar = m_writer.GetOutputVariable(var.name);
m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId };
}
}

View File

@ -26,155 +26,131 @@ namespace Nz
{
namespace
{
class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor
class PreVisitor : public ShaderAst::AstRecursiveVisitor
{
public:
using BuiltinContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::BuiltinVariable>>;
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
PreVisitor(const ShaderAst& shader, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_shader(shader),
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_cache(cache),
m_conditions(conditions),
m_constantCache(constantCache)
{
}
using ShaderAstRecursiveVisitor::Visit;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override
void Visit(ShaderAst::AccessMemberExpression& node) override
{
for (std::size_t index : node.memberIndices)
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));
/*for (std::size_t index : node.memberIdentifiers)
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderNodes::ConditionalExpression& node) override
void Visit(ShaderAst::ConditionalExpression& node) override
{
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);
Visit(node.falsePath);*/
}
void Visit(ShaderNodes::ConditionalStatement& node) override
void Visit(ShaderAst::ConditionalStatement& node) override
{
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
Visit(node.statement);
Visit(node.statement);*/
}
void Visit(ShaderNodes::Constant& node) override
void Visit(ShaderAst::ConstantExpression& node) override
{
std::visit([&](auto&& arg)
{
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
}, node.value);
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderNodes::DeclareVariable& node) override
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
Visit(node.variable);
ShaderAstRecursiveVisitor::Visit(node);
m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType));
for (auto& parameter : node.parameters)
m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type));
}
void Visit(ShaderNodes::Identifier& node) override
void Visit(ShaderAst::DeclareStructStatement& node) override
{
Visit(node.var);
ShaderAstRecursiveVisitor::Visit(node);
for (auto& field : node.description.members)
m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
}
void Visit(ShaderNodes::IntrinsicCall& node) override
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
ShaderAstRecursiveVisitor::Visit(node);
variableTypes.insert(node.varType);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
variableTypes.insert(GetExpressionType(node, m_cache));
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderAst::IntrinsicExpression& node) override
{
AstRecursiveVisitor::Visit(node);
switch (node.intrinsic)
{
// Require GLSL.std.450
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::CrossProduct:
extInsts.emplace("GLSL.std.450");
break;
// Part of SPIR-V core
case ShaderNodes::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::DotProduct:
break;
}
}
void Visit(ShaderNodes::BuiltinVariable& var) override
{
builtinVars.insert(std::static_pointer_cast<const ShaderNodes::BuiltinVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::InputVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
void Visit(ShaderNodes::LocalVariable& var) override
{
localVars.insert(std::static_pointer_cast<const ShaderNodes::LocalVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::OutputVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
void Visit(ShaderNodes::ParameterVariable& var) override
{
paramVars.insert(std::static_pointer_cast<const ShaderNodes::ParameterVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::UniformVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
BuiltinContainer builtinVars;
ExtInstList extInsts;
LocalContainer localVars;
ParameterContainer paramVars;
LocalContainer variableTypes;
private:
const ShaderAst& m_shader;
ShaderAst::AstCache* m_cache;
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
};
template<typename T>
constexpr ShaderNodes::BasicType GetBasicType()
constexpr ShaderAst::BasicType GetBasicType()
{
if constexpr (std::is_same_v<T, bool>)
return ShaderNodes::BasicType::Boolean;
return ShaderAst::BasicType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return(ShaderNodes::BasicType::Float1);
return(ShaderAst::BasicType::Float1);
else if constexpr (std::is_same_v<T, Int32>)
return(ShaderNodes::BasicType::Int1);
return(ShaderAst::BasicType::Int1);
else if constexpr (std::is_same_v<T, Vector2f>)
return(ShaderNodes::BasicType::Float2);
return(ShaderAst::BasicType::Float2);
else if constexpr (std::is_same_v<T, Vector3f>)
return(ShaderNodes::BasicType::Float3);
return(ShaderAst::BasicType::Float3);
else if constexpr (std::is_same_v<T, Vector4f>)
return(ShaderNodes::BasicType::Float4);
return(ShaderAst::BasicType::Float4);
else if constexpr (std::is_same_v<T, Vector2i32>)
return(ShaderNodes::BasicType::Int2);
return(ShaderAst::BasicType::Int2);
else if constexpr (std::is_same_v<T, Vector3i32>)
return(ShaderNodes::BasicType::Int3);
return(ShaderAst::BasicType::Int3);
else if constexpr (std::is_same_v<T, Vector4i32>)
return(ShaderNodes::BasicType::Int4);
return(ShaderAst::BasicType::Int4);
else
static_assert(AlwaysFalse<T>::value, "unhandled type");
}
@ -198,7 +174,7 @@ namespace Nz
tsl::ordered_map<std::string, ExtVar> parameterIds;
tsl::ordered_map<std::string, ExtVar> uniformIds;
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<Func> funcs;
std::vector<SpirvBlock> functionBlocks;
@ -219,13 +195,12 @@ namespace Nz
{
}
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader, const States& conditions)
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
std::string error;
if (!ValidateShader(shader, &error))
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.shader = &shader;
m_context.states = &conditions;
State state;
@ -235,23 +210,19 @@ namespace Nz
m_currentState = nullptr;
});
std::vector<ShaderNodes::StatementPtr> functionStatements;
std::vector<ShaderAst::StatementPtr> functionStatements;
ShaderAstCloner cloner;
PreVisitor preVisitor(shader, conditions, state.constantTypeCache);
for (const auto& func : shader.GetFunctions())
{
functionStatements.emplace_back(cloner.Clone(func.statement));
preVisitor.Visit(func.statement);
}
ShaderAst::AstCloner cloner;
// Register all extended instruction sets
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
shader->Visit(preVisitor);
for (const std::string& extInst : preVisitor.extInsts)
state.extensionInstructions[extInst] = AllocateResultId();
// Register all types
for (const auto& func : shader.GetFunctions())
/*for (const auto& func : shader.GetFunctions())
{
RegisterType(func.returnType);
for (const auto& param : func.parameters)
@ -270,8 +241,8 @@ namespace Nz
for (const auto& func : shader.GetFunctions())
RegisterFunctionType(func.returnType, func.parameters);
for (const auto& local : preVisitor.localVars)
RegisterType(local->type);
for (const auto& type : preVisitor.variableTypes)
RegisterType(type);
for (const auto& builtin : preVisitor.builtinVars)
RegisterType(builtin->type);
@ -283,7 +254,7 @@ namespace Nz
SpirvBuiltIn builtinDecoration;
switch (builtin->entry)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
case ShaderAst::BuiltinEntry::VertexPosition:
variable.debugName = "builtin_VertexPosition";
variable.storageClass = SpirvStorageClass::Output;
@ -294,10 +265,10 @@ namespace Nz
throw std::runtime_error("unexpected builtin type");
}
const ShaderExpressionType& builtinExprType = builtin->type;
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
assert(IsBasicType(builtinExprType));
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
@ -420,7 +391,7 @@ namespace Nz
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void));
assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
@ -475,14 +446,14 @@ namespace Nz
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
}
}*/
std::vector<UInt32> ret;
MergeSections(ret, state.header);
/*MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo);
MergeSections(ret, state.annotations);
MergeSections(ret, state.constants);
MergeSections(ret, state.instructions);
MergeSections(ret, state.instructions);*/
return ret;
}
@ -516,16 +487,16 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
}
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function));
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function));
return SpirvConstantCache::Function{
SpirvConstantCache::BuildType(*m_context.shader, retType),
SpirvConstantCache::BuildType(retType),
std::move(parameterTypes)
};
}
@ -535,12 +506,12 @@ namespace Nz
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
}
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
{
auto it = m_currentState->builtinIds.find(builtin);
assert(it != m_currentState->builtinIds.end());
@ -572,14 +543,14 @@ namespace Nz
return it.value();
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type));
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
}
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
@ -673,20 +644,20 @@ namespace Nz
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderExpressionType type)
UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type));
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
}
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)