Add conditional expression/statement support for shaders

This commit is contained in:
Jérôme Leclercq
2020-11-19 13:56:54 +01:00
parent ad88561245
commit 960817a1f1
45 changed files with 996 additions and 56 deletions

View File

@@ -16,7 +16,6 @@
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
namespace Nz
{
@@ -31,7 +30,7 @@ namespace Nz
GlslWriter(GlslWriter&&) = delete;
~GlslWriter() = default;
std::string Generate(const ShaderAst& shader) override;
std::string Generate(const ShaderAst& shader, const States& conditions = {});
void SetEnv(Environment environment);
@@ -70,6 +69,8 @@ namespace Nz
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::BuiltinVariable& var) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
@@ -91,6 +92,7 @@ namespace Nz
{
const ShaderAst* shader = nullptr;
const ShaderAst::Function* currentFunction = nullptr;
const States* states = nullptr;
};
struct State

View File

@@ -20,6 +20,7 @@ namespace Nz
class NAZARA_SHADER_API ShaderAst
{
public:
struct Condition;
struct Function;
struct FunctionParameter;
struct InputOutput;
@@ -33,12 +34,16 @@ namespace Nz
ShaderAst(ShaderAst&&) noexcept = default;
~ShaderAst() = default;
void AddCondition(std::string name);
void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters = {}, ShaderNodes::BasicType returnType = ShaderNodes::BasicType::Void);
void AddInput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex = {});
void AddOutput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex = {});
void AddStruct(std::string name, std::vector<StructMember> members);
void AddUniform(std::string name, ShaderExpressionType type, std::optional<std::size_t> bindingIndex = {}, std::optional<ShaderNodes::MemoryLayout> memoryLayout = {});
inline const Condition& GetCondition(std::size_t i) const;
inline std::size_t GetConditionCount() const;
inline const std::vector<Condition>& GetConditions() const;
inline const Function& GetFunction(std::size_t i) const;
inline std::size_t GetFunctionCount() const;
inline const std::vector<Function>& GetFunctions() const;
@@ -59,6 +64,11 @@ namespace Nz
ShaderAst& operator=(const ShaderAst&) = default;
ShaderAst& operator=(ShaderAst&&) noexcept = default;
struct Condition
{
std::string name;
};
struct VariableBase
{
std::string name;
@@ -101,6 +111,7 @@ namespace Nz
};
private:
std::vector<Condition> m_conditions;
std::vector<Function> m_functions;
std::vector<InputOutput> m_inputs;
std::vector<InputOutput> m_outputs;

View File

@@ -12,6 +12,22 @@ namespace Nz
{
}
inline auto Nz::ShaderAst::GetCondition(std::size_t i) const -> const Condition&
{
assert(i < m_functions.size());
return m_conditions[i];
}
inline std::size_t ShaderAst::GetConditionCount() const
{
return m_conditions.size();
}
inline auto ShaderAst::GetConditions() const -> const std::vector<Condition>&
{
return m_conditions;
}
inline auto ShaderAst::GetFunction(std::size_t i) const -> const Function&
{
assert(i < m_functions.size());

View File

@@ -38,6 +38,8 @@ namespace Nz
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;

View File

@@ -26,6 +26,8 @@ namespace Nz
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;

View File

@@ -31,6 +31,8 @@ namespace Nz
void Serialize(ShaderNodes::BuiltinVariable& var);
void Serialize(ShaderNodes::Branch& node);
void Serialize(ShaderNodes::Cast& node);
void Serialize(ShaderNodes::ConditionalExpression& node);
void Serialize(ShaderNodes::ConditionalStatement& node);
void Serialize(ShaderNodes::Constant& node);
void Serialize(ShaderNodes::DeclareVariable& node);
void Serialize(ShaderNodes::ExpressionStatement& node);

View File

@@ -41,6 +41,8 @@ namespace Nz
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;

View File

@@ -10,8 +10,6 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <string>
#include <unordered_set>
namespace Nz
{
@@ -23,16 +21,14 @@ namespace Nz
ShaderAstVisitor(ShaderAstVisitor&&) = delete;
virtual ~ShaderAstVisitor();
void EnableCondition(const std::string& name, bool cond);
bool IsConditionEnabled(const std::string& name) const;
void Visit(const ShaderNodes::NodePtr& node);
virtual void Visit(ShaderNodes::AccessMember& node) = 0;
virtual void Visit(ShaderNodes::AssignOp& node) = 0;
virtual void Visit(ShaderNodes::BinaryOp& node) = 0;
virtual void Visit(ShaderNodes::Branch& node) = 0;
virtual void Visit(ShaderNodes::Cast& node) = 0;
virtual void Visit(ShaderNodes::ConditionalExpression& node) = 0;
virtual void Visit(ShaderNodes::ConditionalStatement& node) = 0;
virtual void Visit(ShaderNodes::Constant& node) = 0;
virtual void Visit(ShaderNodes::DeclareVariable& node) = 0;
virtual void Visit(ShaderNodes::ExpressionStatement& node) = 0;
@@ -44,9 +40,6 @@ namespace Nz
ShaderAstVisitor& operator=(const ShaderAstVisitor&) = delete;
ShaderAstVisitor& operator=(ShaderAstVisitor&&) = delete;
private:
std::unordered_set<std::string> m_conditions;
};
}

View File

@@ -22,6 +22,8 @@ namespace Nz
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;

View File

@@ -50,6 +50,7 @@ namespace Nz::ShaderBuilder
constexpr BuiltinBuilder Builtin;
constexpr GenBuilder<ShaderNodes::StatementBlock> Block;
constexpr GenBuilder<ShaderNodes::Branch> Branch;
constexpr GenBuilder<ShaderNodes::ConditionalExpression> ConditionalExpression;
constexpr GenBuilder<ShaderNodes::ConditionalStatement> ConditionalStatement;
constexpr GenBuilder<ShaderNodes::Constant> Constant;
constexpr GenBuilder<ShaderNodes::DeclareVariable> DeclareVariable;

View File

@@ -78,6 +78,7 @@ namespace Nz::ShaderNodes
Branch,
Cast,
Constant,
ConditionalExpression,
ConditionalStatement,
DeclareVariable,
ExpressionStatement,

View File

@@ -217,6 +217,20 @@ namespace Nz
static inline std::shared_ptr<Cast> Build(BasicType castTo, ExpressionPtr* expressions, std::size_t expressionCount);
};
struct NAZARA_SHADER_API ConditionalExpression : public Expression
{
inline ConditionalExpression();
ShaderExpressionType GetExpressionType() const override;
void Visit(ShaderAstVisitor& visitor) override;
std::string conditionName;
ExpressionPtr falsePath;
ExpressionPtr truePath;
static inline std::shared_ptr<ConditionalExpression> Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath);
};
struct NAZARA_SHADER_API Constant : public Expression
{
inline Constant();

View File

@@ -263,6 +263,20 @@ namespace Nz::ShaderNodes
return node;
}
inline ConditionalExpression::ConditionalExpression() :
Expression(NodeType::ConditionalExpression)
{
}
inline std::shared_ptr<ConditionalExpression> ShaderNodes::ConditionalExpression::Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath)
{
auto node = std::make_shared<ConditionalExpression>();
node->conditionName = std::move(condition);
node->falsePath = std::move(falsePath);
node->truePath = std::move(truePath);
return node;
}
inline Constant::Constant() :
Expression(NodeType::Constant)

View File

@@ -10,6 +10,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <string>
#include <unordered_set>
namespace Nz
{
@@ -18,12 +19,17 @@ namespace Nz
class NAZARA_SHADER_API ShaderWriter
{
public:
struct States;
ShaderWriter() = default;
ShaderWriter(const ShaderWriter&) = default;
ShaderWriter(ShaderWriter&&) = default;
virtual ~ShaderWriter();
virtual std::string Generate(const ShaderAst& shader) = 0;
struct States
{
std::unordered_set<std::string> enabledConditions;
};
};
}

View File

@@ -32,6 +32,8 @@ namespace Nz
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::ConditionalExpression& node) override;
void Visit(ShaderNodes::ConditionalStatement& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;

View File

@@ -23,7 +23,7 @@ namespace Nz
{
class SpirvSection;
class NAZARA_SHADER_API SpirvWriter
class NAZARA_SHADER_API SpirvWriter : public ShaderWriter
{
friend class SpirvAstVisitor;
friend class SpirvExpressionLoad;
@@ -38,7 +38,7 @@ namespace Nz
SpirvWriter(SpirvWriter&&) = delete;
~SpirvWriter() = default;
std::vector<UInt32> Generate(const ShaderAst& shader);
std::vector<UInt32> Generate(const ShaderAst& shader, const States& conditions = {});
void SetEnv(Environment environment);
@@ -66,6 +66,8 @@ namespace Nz
UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const;
UInt32 GetTypeId(const ShaderExpressionType& type) const;
inline bool IsConditionEnabled(const std::string& condition) const;
UInt32 ReadInputVariable(const std::string& name);
std::optional<UInt32> ReadInputVariable(const std::string& name, OnlyCache);
UInt32 ReadLocalVariable(const std::string& name);
@@ -88,6 +90,7 @@ namespace Nz
{
const ShaderAst* shader = nullptr;
const ShaderAst::Function* currentFunction = nullptr;
const States* states = nullptr;
};
struct ExtVar

View File

@@ -7,6 +7,10 @@
namespace Nz
{
inline bool SpirvWriter::IsConditionEnabled(const std::string& condition) const
{
return m_context.states->enabledConditions.find(condition) != m_context.states->enabledConditions.end();
}
}
#include <Nazara/Shader/DebugOff.hpp>