Add conditional expression/statement support for shaders
This commit is contained in:
@@ -48,12 +48,13 @@ namespace Nz
|
||||
{
|
||||
}
|
||||
|
||||
std::string GlslWriter::Generate(const ShaderAst& shader)
|
||||
std::string GlslWriter::Generate(const ShaderAst& shader, const States& conditions)
|
||||
{
|
||||
std::string error;
|
||||
if (!ValidateShader(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.states = &conditions;
|
||||
m_context.shader = &shader;
|
||||
|
||||
State state;
|
||||
@@ -461,6 +462,21 @@ namespace Nz
|
||||
Append(")");
|
||||
}
|
||||
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
{
|
||||
if (m_context.states->enabledConditions.count(node.conditionName) != 0)
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
if (m_context.states->enabledConditions.count(node.conditionName) != 0)
|
||||
Visit(node.statement);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Constant& node)
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
|
||||
@@ -7,6 +7,12 @@
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
void ShaderAst::AddCondition(std::string name)
|
||||
{
|
||||
auto& conditionEntry = m_conditions.emplace_back();
|
||||
conditionEntry.name = std::move(name);
|
||||
}
|
||||
|
||||
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderNodes::BasicType returnType)
|
||||
{
|
||||
auto& functionEntry = m_functions.emplace_back();
|
||||
|
||||
@@ -91,6 +91,16 @@ namespace Nz
|
||||
PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath)));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement)));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Constant& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::Constant::Build(node.value));
|
||||
|
||||
@@ -47,6 +47,17 @@ namespace Nz
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
{
|
||||
Visit(node.truePath);
|
||||
Visit(node.falsePath);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
Visit(node.statement);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
|
||||
@@ -47,6 +47,16 @@ namespace Nz
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Constant& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
@@ -179,6 +189,19 @@ namespace Nz
|
||||
Node(expr);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node)
|
||||
{
|
||||
Value(node.conditionName);
|
||||
Node(node.truePath);
|
||||
Node(node.falsePath);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
Value(node.conditionName);
|
||||
Node(node.statement);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node)
|
||||
{
|
||||
UInt32 typeIndex;
|
||||
@@ -306,6 +329,12 @@ namespace Nz
|
||||
}
|
||||
};
|
||||
|
||||
// Conditions
|
||||
m_stream << UInt32(shader.GetConditionCount());
|
||||
for (const auto& cond : shader.GetConditions())
|
||||
m_stream << cond.name;
|
||||
|
||||
// Structs
|
||||
m_stream << UInt32(shader.GetStructCount());
|
||||
for (const auto& s : shader.GetStructs())
|
||||
{
|
||||
@@ -318,9 +347,11 @@ namespace Nz
|
||||
}
|
||||
}
|
||||
|
||||
// Inputs / Outputs
|
||||
SerializeInputOutput(shader.GetInputs());
|
||||
SerializeInputOutput(shader.GetOutputs());
|
||||
|
||||
// Uniforms
|
||||
m_stream << UInt32(shader.GetUniformCount());
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
{
|
||||
@@ -336,6 +367,7 @@ namespace Nz
|
||||
m_stream << UInt32(uniform.memoryLayout.value());
|
||||
}
|
||||
|
||||
// Functions
|
||||
m_stream << UInt32(shader.GetFunctionCount());
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
@@ -495,6 +527,18 @@ namespace Nz
|
||||
|
||||
ShaderAst shader(static_cast<ShaderStageType>(shaderStage));
|
||||
|
||||
// Conditions
|
||||
UInt32 conditionCount;
|
||||
m_stream >> conditionCount;
|
||||
for (UInt32 i = 0; i < conditionCount; ++i)
|
||||
{
|
||||
std::string conditionName;
|
||||
Value(conditionName);
|
||||
|
||||
shader.AddCondition(std::move(conditionName));
|
||||
}
|
||||
|
||||
// Structs
|
||||
UInt32 structCount;
|
||||
m_stream >> structCount;
|
||||
for (UInt32 i = 0; i < structCount; ++i)
|
||||
@@ -514,6 +558,7 @@ namespace Nz
|
||||
shader.AddStruct(std::move(structName), std::move(members));
|
||||
}
|
||||
|
||||
// Inputs
|
||||
UInt32 inputCount;
|
||||
m_stream >> inputCount;
|
||||
for (UInt32 i = 0; i < inputCount; ++i)
|
||||
@@ -529,6 +574,7 @@ namespace Nz
|
||||
shader.AddInput(std::move(inputName), std::move(inputType), location);
|
||||
}
|
||||
|
||||
// Outputs
|
||||
UInt32 outputCount;
|
||||
m_stream >> outputCount;
|
||||
for (UInt32 i = 0; i < outputCount; ++i)
|
||||
@@ -544,6 +590,7 @@ namespace Nz
|
||||
shader.AddOutput(std::move(outputName), std::move(outputType), location);
|
||||
}
|
||||
|
||||
// Uniforms
|
||||
UInt32 uniformCount;
|
||||
m_stream >> uniformCount;
|
||||
for (UInt32 i = 0; i < uniformCount; ++i)
|
||||
@@ -561,6 +608,7 @@ namespace Nz
|
||||
shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout));
|
||||
}
|
||||
|
||||
// Functions
|
||||
UInt32 funcCount;
|
||||
m_stream >> funcCount;
|
||||
for (UInt32 i = 0; i < funcCount; ++i)
|
||||
@@ -614,6 +662,7 @@ namespace Nz
|
||||
HandleType(Branch);
|
||||
HandleType(Cast);
|
||||
HandleType(Constant);
|
||||
HandleType(ConditionalExpression);
|
||||
HandleType(ConditionalStatement);
|
||||
HandleType(DeclareVariable);
|
||||
HandleType(ExpressionStatement);
|
||||
|
||||
@@ -241,6 +241,35 @@ namespace Nz
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
{
|
||||
MandatoryNode(node.truePath);
|
||||
MandatoryNode(node.falsePath);
|
||||
|
||||
for (std::size_t i = 0; i < m_shader.GetConditionCount(); ++i)
|
||||
{
|
||||
const auto& condition = m_shader.GetCondition(i);
|
||||
if (condition.name == node.conditionName)
|
||||
return;
|
||||
}
|
||||
|
||||
throw AstError{ "Condition not found" };
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
MandatoryNode(node.statement);
|
||||
|
||||
for (std::size_t i = 0; i < m_shader.GetConditionCount(); ++i)
|
||||
{
|
||||
const auto& condition = m_shader.GetCondition(i);
|
||||
if (condition.name == node.conditionName)
|
||||
return;
|
||||
}
|
||||
|
||||
throw AstError{ "Condition not found" };
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -9,19 +9,6 @@ namespace Nz
|
||||
{
|
||||
ShaderAstVisitor::~ShaderAstVisitor() = default;
|
||||
|
||||
void ShaderAstVisitor::EnableCondition(const std::string& name, bool cond)
|
||||
{
|
||||
if (cond)
|
||||
m_conditions.insert(name);
|
||||
else
|
||||
m_conditions.erase(name);
|
||||
}
|
||||
|
||||
bool ShaderAstVisitor::IsConditionEnabled(const std::string& name) const
|
||||
{
|
||||
return m_conditions.count(name) != 0;
|
||||
}
|
||||
|
||||
void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node)
|
||||
{
|
||||
node->Visit(*this);
|
||||
|
||||
@@ -33,6 +33,16 @@ namespace Nz
|
||||
throw std::runtime_error("unhandled Cast node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled ConditionalExpression node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled ConditionalStatement node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Constant node");
|
||||
|
||||
@@ -26,8 +26,7 @@ namespace Nz::ShaderNodes
|
||||
|
||||
void ConditionalStatement::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
if (visitor.IsConditionEnabled(conditionName))
|
||||
statement->Visit(visitor);
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
@@ -204,6 +203,18 @@ namespace Nz::ShaderNodes
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType ConditionalExpression::GetExpressionType() const
|
||||
{
|
||||
assert(truePath->GetExpressionType() == falsePath->GetExpressionType());
|
||||
return truePath->GetExpressionType();
|
||||
}
|
||||
|
||||
void ConditionalExpression::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ExpressionCategory SwizzleOp::GetExpressionCategory() const
|
||||
{
|
||||
return expression->GetExpressionCategory();
|
||||
|
||||
@@ -313,6 +313,20 @@ namespace Nz
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
{
|
||||
if (m_writer.IsConditionEnabled(node.conditionName))
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
if (m_writer.IsConditionEnabled(node.conditionName))
|
||||
Visit(node.statement);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Constant& node)
|
||||
{
|
||||
std::visit([&] (const auto& value)
|
||||
|
||||
@@ -33,7 +33,8 @@ namespace Nz
|
||||
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
|
||||
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
|
||||
|
||||
PreVisitor(SpirvConstantCache& constantCache) :
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
{
|
||||
}
|
||||
@@ -49,6 +50,20 @@ namespace Nz
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override
|
||||
{
|
||||
if (m_conditions.enabledConditions.count(node.conditionName) != 0)
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override
|
||||
{
|
||||
if (m_conditions.enabledConditions.count(node.conditionName) != 0)
|
||||
Visit(node.statement);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Constant& node) override
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
@@ -126,6 +141,7 @@ namespace Nz
|
||||
ParameterContainer paramVars;
|
||||
|
||||
private:
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
};
|
||||
|
||||
@@ -193,13 +209,14 @@ namespace Nz
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader)
|
||||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader, const States& conditions)
|
||||
{
|
||||
std::string error;
|
||||
if (!ValidateShader(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.shader = &shader;
|
||||
m_context.states = &conditions;
|
||||
|
||||
State state;
|
||||
m_currentState = &state;
|
||||
@@ -212,7 +229,7 @@ namespace Nz
|
||||
|
||||
ShaderAstCloner cloner;
|
||||
|
||||
PreVisitor preVisitor(state.constantTypeCache);
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache);
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
functionStatements.emplace_back(cloner.Clone(func.statement));
|
||||
@@ -450,7 +467,7 @@ namespace Nz
|
||||
m_environment = std::move(environment);
|
||||
}
|
||||
|
||||
UInt32 Nz::SpirvWriter::AllocateResultId()
|
||||
UInt32 SpirvWriter::AllocateResultId()
|
||||
{
|
||||
return m_currentState->nextVarIndex++;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user