Shader: Add support for numerical fors

This commit is contained in:
Jérôme Leclercq
2022-01-06 20:38:55 +01:00
parent 972d5ffd3f
commit 756fd773a9
24 changed files with 746 additions and 134 deletions

View File

@@ -83,6 +83,13 @@ namespace Nz
inline bool HasValue() const { return setIndex.HasValue(); }
};
struct LangWriter::UnrollAttribute
{
const ShaderAst::AttributeValue<ShaderAst::LoopUnroll>& unroll;
inline bool HasValue() const { return unroll.HasValue(); }
};
struct LangWriter::State
{
const States* states = nullptr;
@@ -103,10 +110,7 @@ namespace Nz
m_currentState = nullptr;
});
ShaderAst::SanitizeVisitor::Options options;
options.removeOptionDeclaration = false;
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader, options);
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader);
AppendHeader();
@@ -277,10 +281,14 @@ namespace Nz
if (!binding.HasValue())
return;
Append("binding(");
if (binding.bindingIndex.IsResultingValue())
Append("binding(", binding.bindingIndex.GetResultingValue(), ")");
Append(binding.bindingIndex.GetResultingValue());
else
binding.bindingIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(BuiltinAttribute builtin)
@@ -288,25 +296,29 @@ namespace Nz
if (!builtin.HasValue())
return;
Append("builtin(");
if (builtin.builtin.IsResultingValue())
{
switch (builtin.builtin.GetResultingValue())
{
case ShaderAst::BuiltinEntry::FragCoord:
Append("builtin(fragcoord)");
Append("fragcoord");
break;
case ShaderAst::BuiltinEntry::FragDepth:
Append("builtin(fragdepth)");
Append("fragdepth");
break;
case ShaderAst::BuiltinEntry::VertexPosition:
Append("builtin(position)");
Append("position");
break;
}
}
else
builtin.builtin.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite)
@@ -314,29 +326,33 @@ namespace Nz
if (!depthWrite.HasValue())
return;
Append("depth_write(");
if (depthWrite.writeMode.IsResultingValue())
{
switch (depthWrite.writeMode.GetResultingValue())
{
case ShaderAst::DepthWriteMode::Greater:
Append("depth_write(greater)");
Append("greater");
break;
case ShaderAst::DepthWriteMode::Less:
Append("depth_write(less)");
Append("less");
break;
case ShaderAst::DepthWriteMode::Replace:
Append("depth_write(replace)");
Append("replace");
break;
case ShaderAst::DepthWriteMode::Unchanged:
Append("depth_write(unchanged)");
Append("unchanged");
break;
}
}
else
depthWrite.writeMode.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests)
@@ -344,15 +360,19 @@ namespace Nz
if (!earlyFragmentTests.HasValue())
return;
Append("early_fragment_tests(");
if (earlyFragmentTests.earlyFragmentTests.IsResultingValue())
{
if (earlyFragmentTests.earlyFragmentTests.GetResultingValue())
Append("early_fragment_tests(true)");
Append("true");
else
Append("early_fragment_tests(false)");
Append("false");
}
else
earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(EntryAttribute entry)
@@ -360,21 +380,25 @@ namespace Nz
if (!entry.HasValue())
return;
Append("entry(");
if (entry.stageType.IsResultingValue())
{
switch (entry.stageType.GetResultingValue())
{
case ShaderStageType::Fragment:
Append("entry(frag)");
Append("frag");
break;
case ShaderStageType::Vertex:
Append("entry(vert)");
Append("vert");
break;
}
}
else
entry.stageType.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(LayoutAttribute entry)
@@ -382,17 +406,19 @@ namespace Nz
if (!entry.HasValue())
return;
Append("layout(");
if (entry.layout.IsResultingValue())
{
switch (entry.layout.GetResultingValue())
{
case StructLayout::Std140:
Append("layout(std140)");
Append("std140");
break;
}
}
else
entry.layout.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(LocationAttribute location)
@@ -400,10 +426,14 @@ namespace Nz
if (!location.HasValue())
return;
Append("location(");
if (location.locationIndex.IsResultingValue())
Append("location(", location.locationIndex.GetResultingValue(), ")");
Append(location.locationIndex.GetResultingValue());
else
location.locationIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(SetAttribute set)
@@ -411,10 +441,45 @@ namespace Nz
if (!set.HasValue())
return;
Append("set(");
if (set.setIndex.IsResultingValue())
Append("set(", set.setIndex.GetResultingValue(), ")");
Append(set.setIndex.GetResultingValue());
else
set.setIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(UnrollAttribute unroll)
{
if (!unroll.HasValue())
return;
Append("unroll(");
if (unroll.unroll.IsResultingValue())
{
switch (unroll.unroll.GetResultingValue())
{
case ShaderAst::LoopUnroll::Always:
Append("always");
break;
case ShaderAst::LoopUnroll::Hint:
Append("hint");
break;
case ShaderAst::LoopUnroll::Never:
Append("never");
break;
default:
break;
}
}
else
unroll.unroll.GetExpression()->Visit(*this);
}
void LangWriter::AppendCommentSection(const std::string& section)
@@ -426,26 +491,6 @@ namespace Nz
AppendLine();
}
void LangWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers)
{
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex);
assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantValueExpression);
auto& constantValue = static_cast<ShaderAst::ConstantValueExpression&>(**memberIndices);
Int32 index = std::get<Int32>(constantValue.value);
const auto& member = structDesc->members[index];
Append(".");
Append(member.name);
if (remainingMembers > 1)
{
assert(IsStructType(member.type));
AppendField(std::get<ShaderAst::StructType>(member.type).structIndex, memberIndices + 1, remainingMembers - 1);
}
}
void LangWriter::AppendLine(const std::string& txt)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@@ -526,24 +571,30 @@ namespace Nz
Append(")");
}
void LangWriter::Visit(ShaderAst::AccessIdentifierExpression& node)
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers)
Append(".", identifier);
}
void LangWriter::Visit(ShaderAst::AccessIndexExpression& node)
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(!IsStructType(exprType));
// For structs, convert indices to field names
if (IsStructType(exprType))
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.indices.data(), node.indices.size());
else
// Array access
for (ShaderAst::ExpressionPtr& expr : node.indices)
{
// Array access
for (ShaderAst::ExpressionPtr& expr : node.indices)
{
Append("[");
Visit(expr);
Append("]");
}
Append("[");
expr->Visit(*this);
Append("]");
}
}
@@ -826,11 +877,36 @@ namespace Nz
Append(";");
}
void LangWriter::Visit(ShaderAst::ForStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in ");
node.fromExpr->Visit(*this);
Append(" -> ");
node.toExpr->Visit(*this);
if (node.stepExpr)
{
Append(" : ");
node.stepExpr->Visit(*this);
}
AppendLine();
EnterScope();
node.statement->Visit(*this);
LeaveScope();
}
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in ");
node.expression->Visit(*this);
AppendLine();