Add conditional expression/statement support for shaders

This commit is contained in:
Jérôme Leclercq
2020-11-19 13:56:54 +01:00
parent ad88561245
commit 960817a1f1
45 changed files with 996 additions and 56 deletions

View File

@@ -48,12 +48,13 @@ namespace Nz
{
}
std::string GlslWriter::Generate(const ShaderAst& shader)
std::string GlslWriter::Generate(const ShaderAst& shader, const States& conditions)
{
std::string error;
if (!ValidateShader(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.states = &conditions;
m_context.shader = &shader;
State state;
@@ -461,6 +462,21 @@ namespace Nz
Append(")");
}
void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node)
{
if (m_context.states->enabledConditions.count(node.conditionName) != 0)
Visit(node.truePath);
else
Visit(node.falsePath);
}
void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node)
{
if (m_context.states->enabledConditions.count(node.conditionName) != 0)
Visit(node.statement);
}
void GlslWriter::Visit(ShaderNodes::Constant& node)
{
std::visit([&](auto&& arg)

View File

@@ -7,6 +7,12 @@
namespace Nz
{
void ShaderAst::AddCondition(std::string name)
{
auto& conditionEntry = m_conditions.emplace_back();
conditionEntry.name = std::move(name);
}
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderNodes::BasicType returnType)
{
auto& functionEntry = m_functions.emplace_back();

View File

@@ -91,6 +91,16 @@ namespace Nz
PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount));
}
void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node)
{
PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath)));
}
void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node)
{
PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement)));
}
void ShaderAstCloner::Visit(ShaderNodes::Constant& node)
{
PushExpression(ShaderNodes::Constant::Build(node.value));

View File

@@ -47,6 +47,17 @@ namespace Nz
}
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node)
{
Visit(node.truePath);
Visit(node.falsePath);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node)
{
Visit(node.statement);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/)
{
/* Nothing to do */

View File

@@ -47,6 +47,16 @@ namespace Nz
Serialize(node);
}
void Visit(ShaderNodes::ConditionalExpression& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ConditionalStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Constant& node) override
{
Serialize(node);
@@ -179,6 +189,19 @@ namespace Nz
Node(expr);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node)
{
Value(node.conditionName);
Node(node.truePath);
Node(node.falsePath);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node)
{
Value(node.conditionName);
Node(node.statement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node)
{
UInt32 typeIndex;
@@ -306,6 +329,12 @@ namespace Nz
}
};
// Conditions
m_stream << UInt32(shader.GetConditionCount());
for (const auto& cond : shader.GetConditions())
m_stream << cond.name;
// Structs
m_stream << UInt32(shader.GetStructCount());
for (const auto& s : shader.GetStructs())
{
@@ -318,9 +347,11 @@ namespace Nz
}
}
// Inputs / Outputs
SerializeInputOutput(shader.GetInputs());
SerializeInputOutput(shader.GetOutputs());
// Uniforms
m_stream << UInt32(shader.GetUniformCount());
for (const auto& uniform : shader.GetUniforms())
{
@@ -336,6 +367,7 @@ namespace Nz
m_stream << UInt32(uniform.memoryLayout.value());
}
// Functions
m_stream << UInt32(shader.GetFunctionCount());
for (const auto& func : shader.GetFunctions())
{
@@ -495,6 +527,18 @@ namespace Nz
ShaderAst shader(static_cast<ShaderStageType>(shaderStage));
// Conditions
UInt32 conditionCount;
m_stream >> conditionCount;
for (UInt32 i = 0; i < conditionCount; ++i)
{
std::string conditionName;
Value(conditionName);
shader.AddCondition(std::move(conditionName));
}
// Structs
UInt32 structCount;
m_stream >> structCount;
for (UInt32 i = 0; i < structCount; ++i)
@@ -514,6 +558,7 @@ namespace Nz
shader.AddStruct(std::move(structName), std::move(members));
}
// Inputs
UInt32 inputCount;
m_stream >> inputCount;
for (UInt32 i = 0; i < inputCount; ++i)
@@ -529,6 +574,7 @@ namespace Nz
shader.AddInput(std::move(inputName), std::move(inputType), location);
}
// Outputs
UInt32 outputCount;
m_stream >> outputCount;
for (UInt32 i = 0; i < outputCount; ++i)
@@ -544,6 +590,7 @@ namespace Nz
shader.AddOutput(std::move(outputName), std::move(outputType), location);
}
// Uniforms
UInt32 uniformCount;
m_stream >> uniformCount;
for (UInt32 i = 0; i < uniformCount; ++i)
@@ -561,6 +608,7 @@ namespace Nz
shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout));
}
// Functions
UInt32 funcCount;
m_stream >> funcCount;
for (UInt32 i = 0; i < funcCount; ++i)
@@ -614,6 +662,7 @@ namespace Nz
HandleType(Branch);
HandleType(Cast);
HandleType(Constant);
HandleType(ConditionalExpression);
HandleType(ConditionalStatement);
HandleType(DeclareVariable);
HandleType(ExpressionStatement);

View File

@@ -241,6 +241,35 @@ namespace Nz
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node)
{
MandatoryNode(node.truePath);
MandatoryNode(node.falsePath);
for (std::size_t i = 0; i < m_shader.GetConditionCount(); ++i)
{
const auto& condition = m_shader.GetCondition(i);
if (condition.name == node.conditionName)
return;
}
throw AstError{ "Condition not found" };
}
void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node)
{
MandatoryNode(node.statement);
for (std::size_t i = 0; i < m_shader.GetConditionCount(); ++i)
{
const auto& condition = m_shader.GetCondition(i);
if (condition.name == node.conditionName)
return;
}
throw AstError{ "Condition not found" };
}
void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/)
{
}

View File

@@ -9,19 +9,6 @@ namespace Nz
{
ShaderAstVisitor::~ShaderAstVisitor() = default;
void ShaderAstVisitor::EnableCondition(const std::string& name, bool cond)
{
if (cond)
m_conditions.insert(name);
else
m_conditions.erase(name);
}
bool ShaderAstVisitor::IsConditionEnabled(const std::string& name) const
{
return m_conditions.count(name) != 0;
}
void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node)
{
node->Visit(*this);

View File

@@ -33,6 +33,16 @@ namespace Nz
throw std::runtime_error("unhandled Cast node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/)
{
throw std::runtime_error("unhandled ConditionalExpression node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/)
{
throw std::runtime_error("unhandled ConditionalStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/)
{
throw std::runtime_error("unhandled Constant node");

View File

@@ -26,8 +26,7 @@ namespace Nz::ShaderNodes
void ConditionalStatement::Visit(ShaderAstVisitor& visitor)
{
if (visitor.IsConditionEnabled(conditionName))
statement->Visit(visitor);
visitor.Visit(*this);
}
@@ -204,6 +203,18 @@ namespace Nz::ShaderNodes
}
ShaderExpressionType ConditionalExpression::GetExpressionType() const
{
assert(truePath->GetExpressionType() == falsePath->GetExpressionType());
return truePath->GetExpressionType();
}
void ConditionalExpression::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ExpressionCategory SwizzleOp::GetExpressionCategory() const
{
return expression->GetExpressionCategory();

View File

@@ -313,6 +313,20 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
Visit(node.truePath);
else
Visit(node.falsePath);
}
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
Visit(node.statement);
}
void SpirvAstVisitor::Visit(ShaderNodes::Constant& node)
{
std::visit([&] (const auto& value)

View File

@@ -33,7 +33,8 @@ namespace Nz
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
PreVisitor(SpirvConstantCache& constantCache) :
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_conditions(conditions),
m_constantCache(constantCache)
{
}
@@ -49,6 +50,20 @@ namespace Nz
ShaderAstRecursiveVisitor::Visit(node);
}
void Visit(ShaderNodes::ConditionalExpression& node) override
{
if (m_conditions.enabledConditions.count(node.conditionName) != 0)
Visit(node.truePath);
else
Visit(node.falsePath);
}
void Visit(ShaderNodes::ConditionalStatement& node) override
{
if (m_conditions.enabledConditions.count(node.conditionName) != 0)
Visit(node.statement);
}
void Visit(ShaderNodes::Constant& node) override
{
std::visit([&](auto&& arg)
@@ -126,6 +141,7 @@ namespace Nz
ParameterContainer paramVars;
private:
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
};
@@ -193,13 +209,14 @@ namespace Nz
{
}
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader)
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader, const States& conditions)
{
std::string error;
if (!ValidateShader(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.shader = &shader;
m_context.states = &conditions;
State state;
m_currentState = &state;
@@ -212,7 +229,7 @@ namespace Nz
ShaderAstCloner cloner;
PreVisitor preVisitor(state.constantTypeCache);
PreVisitor preVisitor(conditions, state.constantTypeCache);
for (const auto& func : shader.GetFunctions())
{
functionStatements.emplace_back(cloner.Clone(func.statement));
@@ -450,7 +467,7 @@ namespace Nz
m_environment = std::move(environment);
}
UInt32 Nz::SpirvWriter::AllocateResultId()
UInt32 SpirvWriter::AllocateResultId()
{
return m_currentState->nextVarIndex++;
}