Add conditional expression/statement support for shaders
This commit is contained in:
@@ -16,7 +16,6 @@
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
@@ -31,7 +30,7 @@ namespace Nz
|
||||
GlslWriter(GlslWriter&&) = delete;
|
||||
~GlslWriter() = default;
|
||||
|
||||
std::string Generate(const ShaderAst& shader) override;
|
||||
std::string Generate(const ShaderAst& shader, const States& conditions = {});
|
||||
|
||||
void SetEnv(Environment environment);
|
||||
|
||||
@@ -70,6 +69,8 @@ namespace Nz
|
||||
void Visit(ShaderNodes::BinaryOp& node) override;
|
||||
void Visit(ShaderNodes::BuiltinVariable& var) override;
|
||||
void Visit(ShaderNodes::Cast& node) override;
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override;
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override;
|
||||
void Visit(ShaderNodes::Constant& node) override;
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override;
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override;
|
||||
@@ -91,6 +92,7 @@ namespace Nz
|
||||
{
|
||||
const ShaderAst* shader = nullptr;
|
||||
const ShaderAst::Function* currentFunction = nullptr;
|
||||
const States* states = nullptr;
|
||||
};
|
||||
|
||||
struct State
|
||||
|
||||
@@ -20,6 +20,7 @@ namespace Nz
|
||||
class NAZARA_SHADER_API ShaderAst
|
||||
{
|
||||
public:
|
||||
struct Condition;
|
||||
struct Function;
|
||||
struct FunctionParameter;
|
||||
struct InputOutput;
|
||||
@@ -33,12 +34,16 @@ namespace Nz
|
||||
ShaderAst(ShaderAst&&) noexcept = default;
|
||||
~ShaderAst() = default;
|
||||
|
||||
void AddCondition(std::string name);
|
||||
void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters = {}, ShaderNodes::BasicType returnType = ShaderNodes::BasicType::Void);
|
||||
void AddInput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex = {});
|
||||
void AddOutput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex = {});
|
||||
void AddStruct(std::string name, std::vector<StructMember> members);
|
||||
void AddUniform(std::string name, ShaderExpressionType type, std::optional<std::size_t> bindingIndex = {}, std::optional<ShaderNodes::MemoryLayout> memoryLayout = {});
|
||||
|
||||
inline const Condition& GetCondition(std::size_t i) const;
|
||||
inline std::size_t GetConditionCount() const;
|
||||
inline const std::vector<Condition>& GetConditions() const;
|
||||
inline const Function& GetFunction(std::size_t i) const;
|
||||
inline std::size_t GetFunctionCount() const;
|
||||
inline const std::vector<Function>& GetFunctions() const;
|
||||
@@ -59,6 +64,11 @@ namespace Nz
|
||||
ShaderAst& operator=(const ShaderAst&) = default;
|
||||
ShaderAst& operator=(ShaderAst&&) noexcept = default;
|
||||
|
||||
struct Condition
|
||||
{
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct VariableBase
|
||||
{
|
||||
std::string name;
|
||||
@@ -101,6 +111,7 @@ namespace Nz
|
||||
};
|
||||
|
||||
private:
|
||||
std::vector<Condition> m_conditions;
|
||||
std::vector<Function> m_functions;
|
||||
std::vector<InputOutput> m_inputs;
|
||||
std::vector<InputOutput> m_outputs;
|
||||
|
||||
@@ -12,6 +12,22 @@ namespace Nz
|
||||
{
|
||||
}
|
||||
|
||||
inline auto Nz::ShaderAst::GetCondition(std::size_t i) const -> const Condition&
|
||||
{
|
||||
assert(i < m_functions.size());
|
||||
return m_conditions[i];
|
||||
}
|
||||
|
||||
inline std::size_t ShaderAst::GetConditionCount() const
|
||||
{
|
||||
return m_conditions.size();
|
||||
}
|
||||
|
||||
inline auto ShaderAst::GetConditions() const -> const std::vector<Condition>&
|
||||
{
|
||||
return m_conditions;
|
||||
}
|
||||
|
||||
inline auto ShaderAst::GetFunction(std::size_t i) const -> const Function&
|
||||
{
|
||||
assert(i < m_functions.size());
|
||||
|
||||
@@ -38,6 +38,8 @@ namespace Nz
|
||||
void Visit(ShaderNodes::BinaryOp& node) override;
|
||||
void Visit(ShaderNodes::Branch& node) override;
|
||||
void Visit(ShaderNodes::Cast& node) override;
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override;
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override;
|
||||
void Visit(ShaderNodes::Constant& node) override;
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override;
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override;
|
||||
|
||||
@@ -26,6 +26,8 @@ namespace Nz
|
||||
void Visit(ShaderNodes::BinaryOp& node) override;
|
||||
void Visit(ShaderNodes::Branch& node) override;
|
||||
void Visit(ShaderNodes::Cast& node) override;
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override;
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override;
|
||||
void Visit(ShaderNodes::Constant& node) override;
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override;
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override;
|
||||
|
||||
@@ -31,6 +31,8 @@ namespace Nz
|
||||
void Serialize(ShaderNodes::BuiltinVariable& var);
|
||||
void Serialize(ShaderNodes::Branch& node);
|
||||
void Serialize(ShaderNodes::Cast& node);
|
||||
void Serialize(ShaderNodes::ConditionalExpression& node);
|
||||
void Serialize(ShaderNodes::ConditionalStatement& node);
|
||||
void Serialize(ShaderNodes::Constant& node);
|
||||
void Serialize(ShaderNodes::DeclareVariable& node);
|
||||
void Serialize(ShaderNodes::ExpressionStatement& node);
|
||||
|
||||
@@ -41,6 +41,8 @@ namespace Nz
|
||||
void Visit(ShaderNodes::BinaryOp& node) override;
|
||||
void Visit(ShaderNodes::Branch& node) override;
|
||||
void Visit(ShaderNodes::Cast& node) override;
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override;
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override;
|
||||
void Visit(ShaderNodes::Constant& node) override;
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override;
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override;
|
||||
|
||||
@@ -10,8 +10,6 @@
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/ShaderNodes.hpp>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
@@ -23,16 +21,14 @@ namespace Nz
|
||||
ShaderAstVisitor(ShaderAstVisitor&&) = delete;
|
||||
virtual ~ShaderAstVisitor();
|
||||
|
||||
void EnableCondition(const std::string& name, bool cond);
|
||||
|
||||
bool IsConditionEnabled(const std::string& name) const;
|
||||
|
||||
void Visit(const ShaderNodes::NodePtr& node);
|
||||
virtual void Visit(ShaderNodes::AccessMember& node) = 0;
|
||||
virtual void Visit(ShaderNodes::AssignOp& node) = 0;
|
||||
virtual void Visit(ShaderNodes::BinaryOp& node) = 0;
|
||||
virtual void Visit(ShaderNodes::Branch& node) = 0;
|
||||
virtual void Visit(ShaderNodes::Cast& node) = 0;
|
||||
virtual void Visit(ShaderNodes::ConditionalExpression& node) = 0;
|
||||
virtual void Visit(ShaderNodes::ConditionalStatement& node) = 0;
|
||||
virtual void Visit(ShaderNodes::Constant& node) = 0;
|
||||
virtual void Visit(ShaderNodes::DeclareVariable& node) = 0;
|
||||
virtual void Visit(ShaderNodes::ExpressionStatement& node) = 0;
|
||||
@@ -44,9 +40,6 @@ namespace Nz
|
||||
|
||||
ShaderAstVisitor& operator=(const ShaderAstVisitor&) = delete;
|
||||
ShaderAstVisitor& operator=(ShaderAstVisitor&&) = delete;
|
||||
|
||||
private:
|
||||
std::unordered_set<std::string> m_conditions;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ namespace Nz
|
||||
void Visit(ShaderNodes::BinaryOp& node) override;
|
||||
void Visit(ShaderNodes::Branch& node) override;
|
||||
void Visit(ShaderNodes::Cast& node) override;
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override;
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override;
|
||||
void Visit(ShaderNodes::Constant& node) override;
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override;
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override;
|
||||
|
||||
@@ -50,6 +50,7 @@ namespace Nz::ShaderBuilder
|
||||
constexpr BuiltinBuilder Builtin;
|
||||
constexpr GenBuilder<ShaderNodes::StatementBlock> Block;
|
||||
constexpr GenBuilder<ShaderNodes::Branch> Branch;
|
||||
constexpr GenBuilder<ShaderNodes::ConditionalExpression> ConditionalExpression;
|
||||
constexpr GenBuilder<ShaderNodes::ConditionalStatement> ConditionalStatement;
|
||||
constexpr GenBuilder<ShaderNodes::Constant> Constant;
|
||||
constexpr GenBuilder<ShaderNodes::DeclareVariable> DeclareVariable;
|
||||
|
||||
@@ -78,6 +78,7 @@ namespace Nz::ShaderNodes
|
||||
Branch,
|
||||
Cast,
|
||||
Constant,
|
||||
ConditionalExpression,
|
||||
ConditionalStatement,
|
||||
DeclareVariable,
|
||||
ExpressionStatement,
|
||||
|
||||
@@ -217,6 +217,20 @@ namespace Nz
|
||||
static inline std::shared_ptr<Cast> Build(BasicType castTo, ExpressionPtr* expressions, std::size_t expressionCount);
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API ConditionalExpression : public Expression
|
||||
{
|
||||
inline ConditionalExpression();
|
||||
|
||||
ShaderExpressionType GetExpressionType() const override;
|
||||
void Visit(ShaderAstVisitor& visitor) override;
|
||||
|
||||
std::string conditionName;
|
||||
ExpressionPtr falsePath;
|
||||
ExpressionPtr truePath;
|
||||
|
||||
static inline std::shared_ptr<ConditionalExpression> Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath);
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API Constant : public Expression
|
||||
{
|
||||
inline Constant();
|
||||
|
||||
@@ -263,6 +263,20 @@ namespace Nz::ShaderNodes
|
||||
return node;
|
||||
}
|
||||
|
||||
inline ConditionalExpression::ConditionalExpression() :
|
||||
Expression(NodeType::ConditionalExpression)
|
||||
{
|
||||
}
|
||||
|
||||
inline std::shared_ptr<ConditionalExpression> ShaderNodes::ConditionalExpression::Build(std::string condition, ExpressionPtr truePath, ExpressionPtr falsePath)
|
||||
{
|
||||
auto node = std::make_shared<ConditionalExpression>();
|
||||
node->conditionName = std::move(condition);
|
||||
node->falsePath = std::move(falsePath);
|
||||
node->truePath = std::move(truePath);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
inline Constant::Constant() :
|
||||
Expression(NodeType::Constant)
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
@@ -18,12 +19,17 @@ namespace Nz
|
||||
class NAZARA_SHADER_API ShaderWriter
|
||||
{
|
||||
public:
|
||||
struct States;
|
||||
|
||||
ShaderWriter() = default;
|
||||
ShaderWriter(const ShaderWriter&) = default;
|
||||
ShaderWriter(ShaderWriter&&) = default;
|
||||
virtual ~ShaderWriter();
|
||||
|
||||
virtual std::string Generate(const ShaderAst& shader) = 0;
|
||||
struct States
|
||||
{
|
||||
std::unordered_set<std::string> enabledConditions;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ namespace Nz
|
||||
void Visit(ShaderNodes::AssignOp& node) override;
|
||||
void Visit(ShaderNodes::BinaryOp& node) override;
|
||||
void Visit(ShaderNodes::Cast& node) override;
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override;
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override;
|
||||
void Visit(ShaderNodes::Constant& node) override;
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override;
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override;
|
||||
|
||||
@@ -23,7 +23,7 @@ namespace Nz
|
||||
{
|
||||
class SpirvSection;
|
||||
|
||||
class NAZARA_SHADER_API SpirvWriter
|
||||
class NAZARA_SHADER_API SpirvWriter : public ShaderWriter
|
||||
{
|
||||
friend class SpirvAstVisitor;
|
||||
friend class SpirvExpressionLoad;
|
||||
@@ -38,7 +38,7 @@ namespace Nz
|
||||
SpirvWriter(SpirvWriter&&) = delete;
|
||||
~SpirvWriter() = default;
|
||||
|
||||
std::vector<UInt32> Generate(const ShaderAst& shader);
|
||||
std::vector<UInt32> Generate(const ShaderAst& shader, const States& conditions = {});
|
||||
|
||||
void SetEnv(Environment environment);
|
||||
|
||||
@@ -66,6 +66,8 @@ namespace Nz
|
||||
UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const;
|
||||
UInt32 GetTypeId(const ShaderExpressionType& type) const;
|
||||
|
||||
inline bool IsConditionEnabled(const std::string& condition) const;
|
||||
|
||||
UInt32 ReadInputVariable(const std::string& name);
|
||||
std::optional<UInt32> ReadInputVariable(const std::string& name, OnlyCache);
|
||||
UInt32 ReadLocalVariable(const std::string& name);
|
||||
@@ -88,6 +90,7 @@ namespace Nz
|
||||
{
|
||||
const ShaderAst* shader = nullptr;
|
||||
const ShaderAst::Function* currentFunction = nullptr;
|
||||
const States* states = nullptr;
|
||||
};
|
||||
|
||||
struct ExtVar
|
||||
|
||||
@@ -7,6 +7,10 @@
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
inline bool SpirvWriter::IsConditionEnabled(const std::string& condition) const
|
||||
{
|
||||
return m_context.states->enabledConditions.find(condition) != m_context.states->enabledConditions.end();
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
||||
Reference in New Issue
Block a user