Shader: Add support for numerical fors

This commit is contained in:
Jérôme Leclercq 2022-01-06 20:38:55 +01:00
parent 972d5ffd3f
commit 756fd773a9
24 changed files with 746 additions and 134 deletions

View File

@ -64,6 +64,7 @@ namespace Nz::ShaderAst
virtual StatementPtr Clone(DeclareVariableStatement& node); virtual StatementPtr Clone(DeclareVariableStatement& node);
virtual StatementPtr Clone(DiscardStatement& node); virtual StatementPtr Clone(DiscardStatement& node);
virtual StatementPtr Clone(ExpressionStatement& node); virtual StatementPtr Clone(ExpressionStatement& node);
virtual StatementPtr Clone(ForStatement& node);
virtual StatementPtr Clone(ForEachStatement& node); virtual StatementPtr Clone(ForEachStatement& node);
virtual StatementPtr Clone(MultiStatement& node); virtual StatementPtr Clone(MultiStatement& node);
virtual StatementPtr Clone(NoOpStatement& node); virtual StatementPtr Clone(NoOpStatement& node);

View File

@ -54,6 +54,7 @@ namespace Nz::ShaderAst
inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs); inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs);
inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs); inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs);
inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs); inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs);
inline bool Compare(const ForStatement& lhs, const ForStatement& rhs);
inline bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs); inline bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs);
inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs); inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs);
inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs); inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs);

View File

@ -458,6 +458,29 @@ namespace Nz::ShaderAst
return true; return true;
} }
bool Compare(const ForStatement& lhs, const ForStatement& rhs)
{
if (!Compare(lhs.varName, rhs.varName))
return false;
if (!Compare(lhs.unroll, rhs.unroll))
return false;
if (!Compare(lhs.fromExpr, rhs.fromExpr))
return false;
if (!Compare(lhs.toExpr, rhs.toExpr))
return false;
if (!Compare(lhs.stepExpr, rhs.stepExpr))
return false;
if (!Compare(lhs.statement, rhs.statement))
return false;
return true;
}
bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs) bool Compare(const ForEachStatement& lhs, const ForEachStatement& rhs)
{ {
if (!Compare(lhs.varName, rhs.varName)) if (!Compare(lhs.varName, rhs.varName))

View File

@ -52,6 +52,7 @@ NAZARA_SHADERAST_STATEMENT(DeclareOptionStatement)
NAZARA_SHADERAST_STATEMENT(DeclareStructStatement) NAZARA_SHADERAST_STATEMENT(DeclareStructStatement)
NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement) NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement)
NAZARA_SHADERAST_STATEMENT(DiscardStatement) NAZARA_SHADERAST_STATEMENT(DiscardStatement)
NAZARA_SHADERAST_STATEMENT(ForStatement)
NAZARA_SHADERAST_STATEMENT(ForEachStatement) NAZARA_SHADERAST_STATEMENT(ForEachStatement)
NAZARA_SHADERAST_STATEMENT(ExpressionStatement) NAZARA_SHADERAST_STATEMENT(ExpressionStatement)
NAZARA_SHADERAST_STATEMENT(MultiStatement) NAZARA_SHADERAST_STATEMENT(MultiStatement)

View File

@ -46,6 +46,7 @@ namespace Nz::ShaderAst
void Visit(DeclareVariableStatement& node) override; void Visit(DeclareVariableStatement& node) override;
void Visit(DiscardStatement& node) override; void Visit(DiscardStatement& node) override;
void Visit(ExpressionStatement& node) override; void Visit(ExpressionStatement& node) override;
void Visit(ForStatement& node) override;
void Visit(ForEachStatement& node) override; void Visit(ForEachStatement& node) override;
void Visit(MultiStatement& node) override; void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override; void Visit(NoOpStatement& node) override;

View File

@ -49,6 +49,7 @@ namespace Nz::ShaderAst
void Serialize(DeclareVariableStatement& node); void Serialize(DeclareVariableStatement& node);
void Serialize(DiscardStatement& node); void Serialize(DiscardStatement& node);
void Serialize(ExpressionStatement& node); void Serialize(ExpressionStatement& node);
void Serialize(ForStatement& node);
void Serialize(ForEachStatement& node); void Serialize(ForEachStatement& node);
void Serialize(MultiStatement& node); void Serialize(MultiStatement& node);
void Serialize(NoOpStatement& node); void Serialize(NoOpStatement& node);

View File

@ -340,6 +340,20 @@ namespace Nz::ShaderAst
ExpressionPtr expression; ExpressionPtr expression;
}; };
struct NAZARA_SHADER_API ForStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
AttributeValue<LoopUnroll> unroll;
std::optional<std::size_t> varIndex;
std::string varName;
ExpressionPtr fromExpr;
ExpressionPtr stepExpr;
ExpressionPtr toExpr;
StatementPtr statement;
};
struct NAZARA_SHADER_API ForEachStatement : Statement struct NAZARA_SHADER_API ForEachStatement : Statement
{ {
NodeType GetType() const override; NodeType GetType() const override;

View File

@ -45,6 +45,7 @@ namespace Nz::ShaderAst
bool removeOptionDeclaration = false; bool removeOptionDeclaration = false;
bool removeScalarSwizzling = false; bool removeScalarSwizzling = false;
bool splitMultipleBranches = false; bool splitMultipleBranches = false;
bool useIdentifierAccessesForStructs = true;
}; };
private: private:
@ -78,6 +79,7 @@ namespace Nz::ShaderAst
StatementPtr Clone(DeclareVariableStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override;
StatementPtr Clone(DiscardStatement& node) override; StatementPtr Clone(DiscardStatement& node) override;
StatementPtr Clone(ExpressionStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override;
StatementPtr Clone(ForStatement& node) override;
StatementPtr Clone(ForEachStatement& node) override; StatementPtr Clone(ForEachStatement& node) override;
StatementPtr Clone(MultiStatement& node) override; StatementPtr Clone(MultiStatement& node) override;
StatementPtr Clone(WhileStatement& node) override; StatementPtr Clone(WhileStatement& node) override;

View File

@ -66,7 +66,6 @@ namespace Nz
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params); template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
void AppendCommentSection(const std::string& section); void AppendCommentSection(const std::string& section);
void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false); void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false);
void AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers);
void AppendHeader(); void AppendHeader();
void AppendLine(const std::string& txt = {}); void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params); template<typename... Args> void AppendLine(Args&&... params);
@ -84,6 +83,7 @@ namespace Nz
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override;

View File

@ -45,6 +45,7 @@ namespace Nz
struct LayoutAttribute; struct LayoutAttribute;
struct LocationAttribute; struct LocationAttribute;
struct SetAttribute; struct SetAttribute;
struct UnrollAttribute;
void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type); void Append(const ShaderAst::ExpressionType& type);
@ -68,9 +69,9 @@ namespace Nz
void AppendAttribute(EntryAttribute entry); void AppendAttribute(EntryAttribute entry);
void AppendAttribute(LayoutAttribute layout); void AppendAttribute(LayoutAttribute layout);
void AppendAttribute(LocationAttribute location); void AppendAttribute(LocationAttribute location);
void AppendAttribute(SetAttribute location); void AppendAttribute(SetAttribute set);
void AppendAttribute(UnrollAttribute unroll);
void AppendCommentSection(const std::string& section); void AppendCommentSection(const std::string& section);
void AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers);
void AppendHeader(); void AppendHeader();
void AppendLine(const std::string& txt = {}); void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params); template<typename... Args> void AppendLine(Args&&... params);
@ -85,6 +86,7 @@ namespace Nz
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);
void Visit(ShaderAst::AccessIdentifierExpression& node) override;
void Visit(ShaderAst::AccessIndexExpression& node) override; void Visit(ShaderAst::AccessIndexExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override;
@ -107,6 +109,7 @@ namespace Nz
void Visit(ShaderAst::DeclareVariableStatement& node) override; void Visit(ShaderAst::DeclareVariableStatement& node) override;
void Visit(ShaderAst::DiscardStatement& node) override; void Visit(ShaderAst::DiscardStatement& node) override;
void Visit(ShaderAst::ExpressionStatement& node) override; void Visit(ShaderAst::ExpressionStatement& node) override;
void Visit(ShaderAst::ForStatement& node) override;
void Visit(ShaderAst::ForEachStatement& node) override; void Visit(ShaderAst::ForEachStatement& node) override;
void Visit(ShaderAst::MultiStatement& node) override; void Visit(ShaderAst::MultiStatement& node) override;
void Visit(ShaderAst::NoOpStatement& node) override; void Visit(ShaderAst::NoOpStatement& node) override;

View File

@ -108,6 +108,12 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ExpressionStatement> operator()(ShaderAst::ExpressionPtr expression) const; inline std::unique_ptr<ShaderAst::ExpressionStatement> operator()(ShaderAst::ExpressionPtr expression) const;
}; };
struct For
{
inline std::unique_ptr<ShaderAst::ForStatement> operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const;
inline std::unique_ptr<ShaderAst::ForStatement> operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const;
};
struct ForEach struct ForEach
{ {
inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const; inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const;
@ -179,6 +185,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::DeclareVariable DeclareVariable; constexpr Impl::DeclareVariable DeclareVariable;
constexpr Impl::ExpressionStatement ExpressionStatement; constexpr Impl::ExpressionStatement ExpressionStatement;
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard; constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::For For;
constexpr Impl::ForEach ForEach; constexpr Impl::ForEach ForEach;
constexpr Impl::Identifier Identifier; constexpr Impl::Identifier Identifier;
constexpr Impl::Intrinsic Intrinsic; constexpr Impl::Intrinsic Intrinsic;

View File

@ -269,6 +269,29 @@ namespace Nz::ShaderBuilder
return expressionStatementNode; return expressionStatementNode;
} }
inline std::unique_ptr<ShaderAst::ForStatement> Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const
{
auto forNode = std::make_unique<ShaderAst::ForStatement>();
forNode->fromExpr = std::move(fromExpression);
forNode->statement = std::move(statement);
forNode->toExpr = std::move(toExpression);
forNode->varName = std::move(varName);
return forNode;
}
inline std::unique_ptr<ShaderAst::ForStatement> Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const
{
auto forNode = std::make_unique<ShaderAst::ForStatement>();
forNode->fromExpr = std::move(fromExpression);
forNode->statement = std::move(statement);
forNode->stepExpr = std::move(stepExpression);
forNode->toExpr = std::move(toExpression);
forNode->varName = std::move(varName);
return forNode;
}
std::unique_ptr<ShaderAst::ForEachStatement> Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const std::unique_ptr<ShaderAst::ForEachStatement> Impl::ForEach::operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const
{ {
auto forEachNode = std::make_unique<ShaderAst::ForEachStatement>(); auto forEachNode = std::make_unique<ShaderAst::ForEachStatement>();

View File

@ -12,6 +12,7 @@
#define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X) #define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X)
#endif #endif
NAZARA_SHADERLANG_TOKEN(Arrow)
NAZARA_SHADERLANG_TOKEN(Assign) NAZARA_SHADERLANG_TOKEN(Assign)
NAZARA_SHADERLANG_TOKEN(BoolFalse) NAZARA_SHADERLANG_TOKEN(BoolFalse)
NAZARA_SHADERLANG_TOKEN(BoolTrue) NAZARA_SHADERLANG_TOKEN(BoolTrue)
@ -33,7 +34,6 @@ NAZARA_SHADERLANG_TOKEN(External)
NAZARA_SHADERLANG_TOKEN(FloatingPointValue) NAZARA_SHADERLANG_TOKEN(FloatingPointValue)
NAZARA_SHADERLANG_TOKEN(For) NAZARA_SHADERLANG_TOKEN(For)
NAZARA_SHADERLANG_TOKEN(FunctionDeclaration) NAZARA_SHADERLANG_TOKEN(FunctionDeclaration)
NAZARA_SHADERLANG_TOKEN(FunctionReturn)
NAZARA_SHADERLANG_TOKEN(GreaterThan) NAZARA_SHADERLANG_TOKEN(GreaterThan)
NAZARA_SHADERLANG_TOKEN(GreaterThanEqual) NAZARA_SHADERLANG_TOKEN(GreaterThanEqual)
NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(IntegerValue)
@ -59,8 +59,8 @@ NAZARA_SHADERLANG_TOKEN(OpenCurlyBracket)
NAZARA_SHADERLANG_TOKEN(OpenSquareBracket) NAZARA_SHADERLANG_TOKEN(OpenSquareBracket)
NAZARA_SHADERLANG_TOKEN(OpenParenthesis) NAZARA_SHADERLANG_TOKEN(OpenParenthesis)
NAZARA_SHADERLANG_TOKEN(Option) NAZARA_SHADERLANG_TOKEN(Option)
NAZARA_SHADERLANG_TOKEN(Semicolon)
NAZARA_SHADERLANG_TOKEN(Return) NAZARA_SHADERLANG_TOKEN(Return)
NAZARA_SHADERLANG_TOKEN(Semicolon)
NAZARA_SHADERLANG_TOKEN(Struct) NAZARA_SHADERLANG_TOKEN(Struct)
NAZARA_SHADERLANG_TOKEN(While) NAZARA_SHADERLANG_TOKEN(While)

View File

@ -170,12 +170,26 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
StatementPtr AstCloner::Clone(ForStatement& node)
{
auto clone = std::make_unique<ForStatement>();
clone->fromExpr = CloneExpression(node.fromExpr);
clone->stepExpr = CloneExpression(node.stepExpr);
clone->toExpr = CloneExpression(node.toExpr);
clone->statement = CloneStatement(node.statement);
clone->unroll = Clone(node.unroll);
clone->varName = node.varName;
return clone;
}
StatementPtr AstCloner::Clone(ForEachStatement& node) StatementPtr AstCloner::Clone(ForEachStatement& node)
{ {
auto clone = std::make_unique<ForEachStatement>(); auto clone = std::make_unique<ForEachStatement>();
clone->expression = CloneExpression(node.expression); clone->expression = CloneExpression(node.expression);
clone->statement = CloneStatement(node.statement); clone->statement = CloneStatement(node.statement);
clone->unroll = Clone(node.unroll); clone->unroll = Clone(node.unroll);
clone->varName = node.varName;
return clone; return clone;
} }

View File

@ -161,6 +161,21 @@ namespace Nz::ShaderAst
node.expression->Visit(*this); node.expression->Visit(*this);
} }
void AstRecursiveVisitor::Visit(ForStatement& node)
{
if (node.fromExpr)
node.fromExpr->Visit(*this);
if (node.toExpr)
node.toExpr->Visit(*this);
if (node.stepExpr)
node.stepExpr->Visit(*this);
if (node.statement)
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(ForEachStatement& node) void AstRecursiveVisitor::Visit(ForEachStatement& node)
{ {
if (node.expression) if (node.expression)

View File

@ -301,6 +301,16 @@ namespace Nz::ShaderAst
Node(node.expression); Node(node.expression);
} }
void AstSerializerBase::Serialize(ForStatement& node)
{
Attribute(node.unroll);
Value(node.varName);
Node(node.fromExpr);
Node(node.toExpr);
Node(node.stepExpr);
Node(node.statement);
}
void AstSerializerBase::Serialize(ForEachStatement& node) void AstSerializerBase::Serialize(ForEachStatement& node)
{ {
Attribute(node.unroll); Attribute(node.unroll);

View File

@ -163,19 +163,6 @@ namespace Nz::ShaderAst
const ExpressionType& exprType = GetExpressionType(*indexedExpr); const ExpressionType& exprType = GetExpressionType(*indexedExpr);
if (IsStructType(exprType)) if (IsStructType(exprType))
{ {
// Transform to AccessIndexExpression
AccessIndexExpression* accessIndexPtr;
if (indexedExpr->GetType() != NodeType::AccessIndexExpression)
{
std::unique_ptr<AccessIndexExpression> accessIndex = std::make_unique<AccessIndexExpression>();
accessIndex->expr = std::move(indexedExpr);
accessIndexPtr = accessIndex.get();
indexedExpr = std::move(accessIndex);
}
else
accessIndexPtr = static_cast<AccessIndexExpression*>(indexedExpr.get());
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
assert(structIndex < m_context->structs.size()); assert(structIndex < m_context->structs.size());
const StructDescription* s = m_context->structs[structIndex]; const StructDescription* s = m_context->structs[structIndex];
@ -200,9 +187,43 @@ namespace Nz::ShaderAst
if (!fieldPtr) if (!fieldPtr)
throw AstError{ "unknown field " + identifier }; throw AstError{ "unknown field " + identifier };
if (m_context->options.useIdentifierAccessesForStructs)
{
// Use a AccessIdentifierExpression
AccessIdentifierExpression* accessIdentifierPtr;
if (indexedExpr->GetType() != NodeType::AccessIdentifierExpression)
{
std::unique_ptr<AccessIdentifierExpression> accessIndex = std::make_unique<AccessIdentifierExpression>();
accessIndex->expr = std::move(indexedExpr);
accessIdentifierPtr = accessIndex.get();
indexedExpr = std::move(accessIndex);
}
else
accessIdentifierPtr = static_cast<AccessIdentifierExpression*>(indexedExpr.get());
accessIdentifierPtr->identifiers.push_back(s->members[fieldIndex].name);
accessIdentifierPtr->cachedExpressionType = ResolveType(fieldPtr->type);
}
else
{
// Transform to AccessIndexExpression
AccessIndexExpression* accessIndexPtr;
if (indexedExpr->GetType() != NodeType::AccessIndexExpression)
{
std::unique_ptr<AccessIndexExpression> accessIndex = std::make_unique<AccessIndexExpression>();
accessIndex->expr = std::move(indexedExpr);
accessIndexPtr = accessIndex.get();
indexedExpr = std::move(accessIndex);
}
else
accessIndexPtr = static_cast<AccessIndexExpression*>(indexedExpr.get());
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex)); accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex));
accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type); accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type);
} }
}
else if (IsPrimitiveType(exprType) || IsVectorType(exprType)) else if (IsPrimitiveType(exprType) || IsVectorType(exprType))
{ {
// Swizzle expression // Swizzle expression
@ -269,6 +290,8 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<AccessIndexExpression>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<AccessIndexExpression>(AstCloner::Clone(node));
Validate(*clone); Validate(*clone);
// TODO: Handle AccessIndex on structs with m_context->options.useIdentifierAccessesForStructs
return clone; return clone;
} }
@ -829,9 +852,180 @@ namespace Nz::ShaderAst
return AstCloner::Clone(node); return AstCloner::Clone(node);
} }
StatementPtr SanitizeVisitor::Clone(ForStatement& node)
{
if (node.varName.empty())
throw AstError{ "numerical for variable name cannot be empty" };
auto fromExpr = CloneExpression(MandatoryExpr(node.fromExpr));
auto stepExpr = CloneExpression(node.stepExpr);
auto toExpr = CloneExpression(MandatoryExpr(node.toExpr));
MandatoryStatement(node.statement);
const ExpressionType& fromExprType = GetExpressionType(*fromExpr);
if (!IsPrimitiveType(fromExprType))
throw AstError{ "numerical for from expression must be an integer or unsigned integer" };
PrimitiveType fromType = std::get<PrimitiveType>(fromExprType);
if (fromType != PrimitiveType::Int32 && fromType != PrimitiveType::UInt32)
throw AstError{ "numerical for from expression must be an integer or unsigned integer" };
const ExpressionType& toExprType = GetExpressionType(*fromExpr);
if (toExprType != fromExprType)
throw AstError{ "numerical for to expression type must match from expression type" };
if (stepExpr)
{
const ExpressionType& stepExprType = GetExpressionType(*fromExpr);
if (stepExprType != fromExprType)
throw AstError{ "numerical for step expression type must match from expression type" };
}
AttributeValue<LoopUnroll> unrollValue;
if (node.unroll.HasValue())
{
unrollValue = ComputeAttributeValue(node.unroll);
if (unrollValue.GetResultingValue() == LoopUnroll::Always)
{
PushScope();
auto multi = std::make_unique<MultiStatement>();
auto Unroll = [&](auto dummy)
{
using T = std::decay_t<decltype(dummy)>;
T counter = std::get<T>(ComputeConstantValue(*fromExpr));
T to = std::get<T>(ComputeConstantValue(*toExpr));
T step = (stepExpr) ? std::get<T>(ComputeConstantValue(*stepExpr)) : T(1);
for (; counter < to; counter += step)
{
auto var = ShaderBuilder::DeclareVariable(node.varName, ShaderBuilder::Constant(counter));
Validate(*var);
multi->statements.emplace_back(std::move(var));
multi->statements.emplace_back(CloneStatement(node.statement));
}
};
switch (fromType)
{
case PrimitiveType::Int32:
Unroll(Int32{});
break;
case PrimitiveType::UInt32:
Unroll(UInt32{});
break;
default:
throw AstError{ "internal error" };
}
PopScope();
return multi;
}
}
if (m_context->options.reduceLoopsToWhile)
{
PushScope();
auto multi = std::make_unique<MultiStatement>();
// Counter variable
auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr));
Validate(*counterVariable);
std::size_t counterVarIndex = counterVariable->varIndex.value();
multi->statements.emplace_back(std::move(counterVariable));
// Target variable
auto targetVariable = ShaderBuilder::DeclareVariable("to", std::move(toExpr));
Validate(*targetVariable);
std::size_t targetVarIndex = targetVariable->varIndex.value();
multi->statements.emplace_back(std::move(targetVariable));
// Step variable
std::optional<std::size_t> stepVarIndex;
if (stepExpr)
{
auto stepVariable = ShaderBuilder::DeclareVariable("step", std::move(stepExpr));
Validate(*stepVariable);
stepVarIndex = stepVariable->varIndex;
multi->statements.emplace_back(std::move(stepVariable));
}
// While
auto whileStatement = std::make_unique<WhileStatement>();
whileStatement->unroll = std::move(unrollValue);
// While condition
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, fromType), ShaderBuilder::Variable(targetVarIndex, fromType));
Validate(*condition);
whileStatement->condition = std::move(condition);
// While body
auto body = std::make_unique<MultiStatement>();
body->statements.reserve(2);
body->statements.emplace_back(CloneStatement(node.statement));
ExpressionPtr incrExpr;
if (stepVarIndex)
incrExpr = ShaderBuilder::Variable(*stepVarIndex, fromType);
else
incrExpr = (fromType == PrimitiveType::Int32) ? ShaderBuilder::Constant(1) : ShaderBuilder::Constant(1u);
auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, fromType), std::move(incrExpr));
Validate(*incrCounter);
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));
whileStatement->body = std::move(body);
multi->statements.emplace_back(std::move(whileStatement));
PopScope();
return multi;
}
else
{
auto clone = std::make_unique<ForStatement>();
clone->fromExpr = std::move(fromExpr);
clone->stepExpr = std::move(stepExpr);
clone->toExpr = std::move(toExpr);
clone->varName = node.varName;
clone->unroll = std::move(unrollValue);
PushScope();
{
clone->varIndex = RegisterVariable(node.varName, fromExprType);
clone->statement = CloneStatement(node.statement);
}
PopScope();
SanitizeIdentifier(clone->varName);
return clone;
}
}
StatementPtr SanitizeVisitor::Clone(ForEachStatement& node) StatementPtr SanitizeVisitor::Clone(ForEachStatement& node)
{ {
auto expr = CloneExpression(node.expression); auto expr = CloneExpression(MandatoryExpr(node.expression));
if (node.varName.empty())
throw AstError{ "for-each variable name cannot be empty"};
const ExpressionType& exprType = GetExpressionType(*expr); const ExpressionType& exprType = GetExpressionType(*expr);
ExpressionType innerType; ExpressionType innerType;
@ -849,6 +1043,8 @@ namespace Nz::ShaderAst
unrollValue = ComputeAttributeValue(node.unroll); unrollValue = ComputeAttributeValue(node.unroll);
if (unrollValue.GetResultingValue() == LoopUnroll::Always) if (unrollValue.GetResultingValue() == LoopUnroll::Always)
{ {
PushScope();
// Repeat code // Repeat code
auto multi = std::make_unique<MultiStatement>(); auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType)) if (IsArrayType(exprType))
@ -869,6 +1065,8 @@ namespace Nz::ShaderAst
} }
} }
PopScope();
return multi; return multi;
} }
} }
@ -943,7 +1141,7 @@ namespace Nz::ShaderAst
} }
PopScope(); PopScope();
SanitizeIdentifier(node.varName); SanitizeIdentifier(clone->varName);
return clone; return clone;
} }

View File

@ -390,41 +390,6 @@ namespace Nz
AppendLine((forward) ? ");" : ")"); AppendLine((forward) ? ");" : ")");
} }
void GlslWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers)
{
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex);
assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantValueExpression);
auto& constantValue = static_cast<ShaderAst::ConstantValueExpression&>(**memberIndices);
Int32 index = std::get<Int32>(constantValue.value);
assert(index >= 0);
auto it = structDesc->members.begin();
for (; it != structDesc->members.end(); ++it)
{
const auto& member = *it;
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
if (index == 0)
break;
index--;
}
assert(it != structDesc->members.end());
const auto& member = *it;
Append(".");
Append(member.name);
if (remainingMembers > 1)
{
assert(IsStructType(member.type));
AppendField(std::get<ShaderAst::StructType>(member.type).structIndex, memberIndices + 1, remainingMembers - 1);
}
}
void GlslWriter::AppendHeader() void GlslWriter::AppendHeader()
{ {
unsigned int glslVersion; unsigned int glslVersion;
@ -734,17 +699,24 @@ namespace Nz
Append(")"); Append(")");
} }
void GlslWriter::Visit(ShaderAst::AccessIdentifierExpression& node)
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers)
Append(".", identifier);
}
void GlslWriter::Visit(ShaderAst::AccessIndexExpression& node) void GlslWriter::Visit(ShaderAst::AccessIndexExpression& node)
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(!IsStructType(exprType));
// For structs, convert indices to field names
if (IsStructType(exprType))
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.indices.data(), node.indices.size());
else
{
// Array access // Array access
for (ShaderAst::ExpressionPtr& expr : node.indices) for (ShaderAst::ExpressionPtr& expr : node.indices)
{ {
@ -753,7 +725,6 @@ namespace Nz
Append("]"); Append("]");
} }
} }
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node) void GlslWriter::Visit(ShaderAst::AssignExpression& node)
{ {

View File

@ -83,6 +83,13 @@ namespace Nz
inline bool HasValue() const { return setIndex.HasValue(); } inline bool HasValue() const { return setIndex.HasValue(); }
}; };
struct LangWriter::UnrollAttribute
{
const ShaderAst::AttributeValue<ShaderAst::LoopUnroll>& unroll;
inline bool HasValue() const { return unroll.HasValue(); }
};
struct LangWriter::State struct LangWriter::State
{ {
const States* states = nullptr; const States* states = nullptr;
@ -103,10 +110,7 @@ namespace Nz
m_currentState = nullptr; m_currentState = nullptr;
}); });
ShaderAst::SanitizeVisitor::Options options; ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader);
options.removeOptionDeclaration = false;
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader, options);
AppendHeader(); AppendHeader();
@ -277,10 +281,14 @@ namespace Nz
if (!binding.HasValue()) if (!binding.HasValue())
return; return;
Append("binding(");
if (binding.bindingIndex.IsResultingValue()) if (binding.bindingIndex.IsResultingValue())
Append("binding(", binding.bindingIndex.GetResultingValue(), ")"); Append(binding.bindingIndex.GetResultingValue());
else else
binding.bindingIndex.GetExpression()->Visit(*this); binding.bindingIndex.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(BuiltinAttribute builtin) void LangWriter::AppendAttribute(BuiltinAttribute builtin)
@ -288,25 +296,29 @@ namespace Nz
if (!builtin.HasValue()) if (!builtin.HasValue())
return; return;
Append("builtin(");
if (builtin.builtin.IsResultingValue()) if (builtin.builtin.IsResultingValue())
{ {
switch (builtin.builtin.GetResultingValue()) switch (builtin.builtin.GetResultingValue())
{ {
case ShaderAst::BuiltinEntry::FragCoord: case ShaderAst::BuiltinEntry::FragCoord:
Append("builtin(fragcoord)"); Append("fragcoord");
break; break;
case ShaderAst::BuiltinEntry::FragDepth: case ShaderAst::BuiltinEntry::FragDepth:
Append("builtin(fragdepth)"); Append("fragdepth");
break; break;
case ShaderAst::BuiltinEntry::VertexPosition: case ShaderAst::BuiltinEntry::VertexPosition:
Append("builtin(position)"); Append("position");
break; break;
} }
} }
else else
builtin.builtin.GetExpression()->Visit(*this); builtin.builtin.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite) void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite)
@ -314,29 +326,33 @@ namespace Nz
if (!depthWrite.HasValue()) if (!depthWrite.HasValue())
return; return;
Append("depth_write(");
if (depthWrite.writeMode.IsResultingValue()) if (depthWrite.writeMode.IsResultingValue())
{ {
switch (depthWrite.writeMode.GetResultingValue()) switch (depthWrite.writeMode.GetResultingValue())
{ {
case ShaderAst::DepthWriteMode::Greater: case ShaderAst::DepthWriteMode::Greater:
Append("depth_write(greater)"); Append("greater");
break; break;
case ShaderAst::DepthWriteMode::Less: case ShaderAst::DepthWriteMode::Less:
Append("depth_write(less)"); Append("less");
break; break;
case ShaderAst::DepthWriteMode::Replace: case ShaderAst::DepthWriteMode::Replace:
Append("depth_write(replace)"); Append("replace");
break; break;
case ShaderAst::DepthWriteMode::Unchanged: case ShaderAst::DepthWriteMode::Unchanged:
Append("depth_write(unchanged)"); Append("unchanged");
break; break;
} }
} }
else else
depthWrite.writeMode.GetExpression()->Visit(*this); depthWrite.writeMode.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests) void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests)
@ -344,15 +360,19 @@ namespace Nz
if (!earlyFragmentTests.HasValue()) if (!earlyFragmentTests.HasValue())
return; return;
Append("early_fragment_tests(");
if (earlyFragmentTests.earlyFragmentTests.IsResultingValue()) if (earlyFragmentTests.earlyFragmentTests.IsResultingValue())
{ {
if (earlyFragmentTests.earlyFragmentTests.GetResultingValue()) if (earlyFragmentTests.earlyFragmentTests.GetResultingValue())
Append("early_fragment_tests(true)"); Append("true");
else else
Append("early_fragment_tests(false)"); Append("false");
} }
else else
earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this); earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(EntryAttribute entry) void LangWriter::AppendAttribute(EntryAttribute entry)
@ -360,21 +380,25 @@ namespace Nz
if (!entry.HasValue()) if (!entry.HasValue())
return; return;
Append("entry(");
if (entry.stageType.IsResultingValue()) if (entry.stageType.IsResultingValue())
{ {
switch (entry.stageType.GetResultingValue()) switch (entry.stageType.GetResultingValue())
{ {
case ShaderStageType::Fragment: case ShaderStageType::Fragment:
Append("entry(frag)"); Append("frag");
break; break;
case ShaderStageType::Vertex: case ShaderStageType::Vertex:
Append("entry(vert)"); Append("vert");
break; break;
} }
} }
else else
entry.stageType.GetExpression()->Visit(*this); entry.stageType.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(LayoutAttribute entry) void LangWriter::AppendAttribute(LayoutAttribute entry)
@ -382,17 +406,19 @@ namespace Nz
if (!entry.HasValue()) if (!entry.HasValue())
return; return;
Append("layout(");
if (entry.layout.IsResultingValue()) if (entry.layout.IsResultingValue())
{ {
switch (entry.layout.GetResultingValue()) switch (entry.layout.GetResultingValue())
{ {
case StructLayout::Std140: case StructLayout::Std140:
Append("layout(std140)"); Append("std140");
break; break;
} }
} }
else else
entry.layout.GetExpression()->Visit(*this); entry.layout.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(LocationAttribute location) void LangWriter::AppendAttribute(LocationAttribute location)
@ -400,10 +426,14 @@ namespace Nz
if (!location.HasValue()) if (!location.HasValue())
return; return;
Append("location(");
if (location.locationIndex.IsResultingValue()) if (location.locationIndex.IsResultingValue())
Append("location(", location.locationIndex.GetResultingValue(), ")"); Append(location.locationIndex.GetResultingValue());
else else
location.locationIndex.GetExpression()->Visit(*this); location.locationIndex.GetExpression()->Visit(*this);
Append(")");
} }
void LangWriter::AppendAttribute(SetAttribute set) void LangWriter::AppendAttribute(SetAttribute set)
@ -411,10 +441,45 @@ namespace Nz
if (!set.HasValue()) if (!set.HasValue())
return; return;
Append("set(");
if (set.setIndex.IsResultingValue()) if (set.setIndex.IsResultingValue())
Append("set(", set.setIndex.GetResultingValue(), ")"); Append(set.setIndex.GetResultingValue());
else else
set.setIndex.GetExpression()->Visit(*this); set.setIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(UnrollAttribute unroll)
{
if (!unroll.HasValue())
return;
Append("unroll(");
if (unroll.unroll.IsResultingValue())
{
switch (unroll.unroll.GetResultingValue())
{
case ShaderAst::LoopUnroll::Always:
Append("always");
break;
case ShaderAst::LoopUnroll::Hint:
Append("hint");
break;
case ShaderAst::LoopUnroll::Never:
Append("never");
break;
default:
break;
}
}
else
unroll.unroll.GetExpression()->Visit(*this);
} }
void LangWriter::AppendCommentSection(const std::string& section) void LangWriter::AppendCommentSection(const std::string& section)
@ -426,26 +491,6 @@ namespace Nz
AppendLine(); AppendLine();
} }
void LangWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers)
{
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex);
assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantValueExpression);
auto& constantValue = static_cast<ShaderAst::ConstantValueExpression&>(**memberIndices);
Int32 index = std::get<Int32>(constantValue.value);
const auto& member = structDesc->members[index];
Append(".");
Append(member.name);
if (remainingMembers > 1)
{
assert(IsStructType(member.type));
AppendField(std::get<ShaderAst::StructType>(member.type).structIndex, memberIndices + 1, remainingMembers - 1);
}
}
void LangWriter::AppendLine(const std::string& txt) void LangWriter::AppendLine(const std::string& txt)
{ {
NazaraAssert(m_currentState, "This function should only be called while processing an AST"); NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@ -526,26 +571,32 @@ namespace Nz
Append(")"); Append(")");
} }
void LangWriter::Visit(ShaderAst::AccessIdentifierExpression& node)
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType));
for (const std::string& identifier : node.identifiers)
Append(".", identifier);
}
void LangWriter::Visit(ShaderAst::AccessIndexExpression& node) void LangWriter::Visit(ShaderAst::AccessIndexExpression& node)
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(!IsStructType(exprType));
// For structs, convert indices to field names
if (IsStructType(exprType))
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.indices.data(), node.indices.size());
else
{
// Array access // Array access
for (ShaderAst::ExpressionPtr& expr : node.indices) for (ShaderAst::ExpressionPtr& expr : node.indices)
{ {
Append("["); Append("[");
Visit(expr); expr->Visit(*this);
Append("]"); Append("]");
} }
} }
}
void LangWriter::Visit(ShaderAst::AssignExpression& node) void LangWriter::Visit(ShaderAst::AssignExpression& node)
{ {
@ -826,11 +877,36 @@ namespace Nz
Append(";"); Append(";");
} }
void LangWriter::Visit(ShaderAst::ForStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in ");
node.fromExpr->Visit(*this);
Append(" -> ");
node.toExpr->Visit(*this);
if (node.stepExpr)
{
Append(" : ");
node.stepExpr->Visit(*this);
}
AppendLine();
EnterScope();
node.statement->Visit(*this);
LeaveScope();
}
void LangWriter::Visit(ShaderAst::ForEachStatement& node) void LangWriter::Visit(ShaderAst::ForEachStatement& node)
{ {
assert(node.varIndex); assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName); RegisterVariable(*node.varIndex, node.varName);
AppendAttributes(true, UnrollAttribute{ node.unroll });
Append("for ", node.varName, " in "); Append("for ", node.varName, " in ");
node.expression->Visit(*this); node.expression->Visit(*this);
AppendLine(); AppendLine();

View File

@ -113,7 +113,7 @@ namespace Nz::ShaderLang
if (next == '>') if (next == '>')
{ {
currentPos++; currentPos++;
tokenType = TokenType::FunctionReturn; tokenType = TokenType::Arrow;
break; break;
} }
else if (next == '=') else if (next == '=')

View File

@ -614,16 +614,54 @@ namespace Nz::ShaderLang
ShaderAst::ExpressionPtr expr = ParseExpression(); ShaderAst::ExpressionPtr expr = ParseExpression();
if (Peek().type == TokenType::Arrow)
{
// Numerical for
Consume();
ShaderAst::ExpressionPtr toExpr = ParseExpression();
ShaderAst::ExpressionPtr stepExpr;
if (Peek().type == TokenType::Colon)
{
Consume();
stepExpr = ParseExpression();
}
ShaderAst::StatementPtr statement = ParseStatement(); ShaderAst::StatementPtr statement = ParseStatement();
auto forEach = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement)); auto forNode = ShaderBuilder::For(std::move(varName), std::move(expr), std::move(toExpr), std::move(stepExpr), std::move(statement));
// TODO: Deduplicate code
for (auto&& [attributeType, arg] : attributes) for (auto&& [attributeType, arg] : attributes)
{ {
switch (attributeType) switch (attributeType)
{ {
case ShaderAst::AttributeType::Unroll: case ShaderAst::AttributeType::Unroll:
HandleUniqueStringAttribute("unroll", s_unrollModes, forEach->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always)); HandleUniqueStringAttribute("unroll", s_unrollModes, forNode->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
break;
default:
throw AttributeError{ "unhandled attribute for numerical for" };
}
}
return forNode;
}
else
{
// For each
ShaderAst::StatementPtr statement = ParseStatement();
auto forEachNode = ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
// TODO: Deduplicate code
for (auto&& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Unroll:
HandleUniqueStringAttribute("unroll", s_unrollModes, forEachNode->unroll, std::move(arg), std::make_optional(ShaderAst::LoopUnroll::Always));
break; break;
default: default:
@ -631,7 +669,8 @@ namespace Nz::ShaderLang
} }
} }
return forEach; return forEachNode;
}
} }
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody() std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
@ -668,7 +707,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingParenthesis); Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::ExpressionType returnType; ShaderAst::ExpressionType returnType;
if (Peek().type == TokenType::FunctionReturn) if (Peek().type == TokenType::Arrow)
{ {
Consume(); Consume();
returnType = ParseType(); returnType = ParseType();

View File

@ -491,6 +491,7 @@ namespace Nz
options.removeCompoundAssignments = true; options.removeCompoundAssignments = true;
options.removeOptionDeclaration = true; options.removeOptionDeclaration = true;
options.splitMultipleBranches = true; options.splitMultipleBranches = true;
options.useIdentifierAccessesForStructs = false;
sanitizedAst = ShaderAst::Sanitize(shader, options); sanitizedAst = ShaderAst::Sanitize(shader, options);
targetAst = sanitizedAst.get(); targetAst = sanitizedAst.get();

View File

@ -110,6 +110,63 @@ fn main()
} }
} }
WHEN("using [unroll] attribute on numerical for")
{
std::string_view sourceCode = R"(
const LightCount = 3;
[layout(std140)]
struct Light
{
color: vec4<f32>
}
[layout(std140)]
struct LightData
{
lights: [Light; LightCount]
}
external
{
[set(0), binding(0)] data: uniform<LightData>
}
[entry(frag)]
fn main()
{
let color = (0.0).xxxx;
[unroll]
for i in 0 -> 10 : 2
{
color += data.lights[i].color;
}
}
)";
Nz::ShaderAst::StatementPtr shader;
REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode));
ExpectOutput(*shader, {}, R"(
[entry(frag)]
fn main()
{
let color: vec4<f32> = (0.000000).xxxx;
let i: i32 = 0;
color += data.lights[i].color;
let i: i32 = 2;
color += data.lights[i].color;
let i: i32 = 4;
color += data.lights[i].color;
let i: i32 = 6;
color += data.lights[i].color;
let i: i32 = 8;
color += data.lights[i].color;
}
)");
}
WHEN("using [unroll] attribute on for-each") WHEN("using [unroll] attribute on for-each")
{ {
std::string_view sourceCode = R"( std::string_view sourceCode = R"(

View File

@ -91,6 +91,160 @@ OpReturn
OpFunctionEnd)"); OpFunctionEnd)");
} }
WHEN("using a for range")
{
std::string_view nzslSource = R"(
[entry(frag)]
fn main()
{
let x = 0;
for v in 0 -> 10
{
x += v;
}
}
)";
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
ExpectGLSL(*shader, R"(
void main()
{
int x = 0;
int v = 0;
int to = 10;
while (v < to)
{
x += v;
v += 1;
}
}
)");
ExpectNZSL(*shader, R"(
[entry(frag)]
fn main()
{
let x: i32 = 0;
for v in 0 -> 10
{
x += v;
}
}
)");
ExpectSpirV(*shader, R"(
OpFunction
OpLabel
OpVariable
OpVariable
OpVariable
OpStore
OpStore
OpStore
OpBranch
OpLabel
OpLoad
OpLoad
OpSLessThan
OpLoopMerge
OpBranchConditional
OpLabel
OpLoad
OpLoad
OpIAdd
OpStore
OpLoad
OpIAdd
OpStore
OpBranch
OpLabel
OpReturn
OpFunctionEnd)");
}
WHEN("using a for range with step")
{
std::string_view nzslSource = R"(
[entry(frag)]
fn main()
{
let x = 0;
for v in 0 -> 10 : 2
{
x += v;
}
}
)";
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
ExpectGLSL(*shader, R"(
void main()
{
int x = 0;
int v = 0;
int to = 10;
int step = 2;
while (v < to)
{
x += v;
v += step;
}
}
)");
ExpectNZSL(*shader, R"(
[entry(frag)]
fn main()
{
let x: i32 = 0;
for v in 0 -> 10 : 2
{
x += v;
}
}
)");
ExpectSpirV(*shader, R"(
OpFunction
OpLabel
OpVariable
OpVariable
OpVariable
OpVariable
OpStore
OpStore
OpStore
OpStore
OpBranch
OpLabel
OpLoad
OpLoad
OpSLessThan
OpLoopMerge
OpBranchConditional
OpLabel
OpLoad
OpLoad
OpIAdd
OpStore
OpLoad
OpLoad
OpIAdd
OpStore
OpBranch
OpLabel
OpReturn
OpFunctionEnd)");
}
WHEN("using a for-each") WHEN("using a for-each")
{ {
std::string_view nzslSource = R"( std::string_view nzslSource = R"(