Shader: Add support for for-each statements and improve arrays

This commit is contained in:
Jérôme Leclercq
2022-01-02 22:02:11 +01:00
parent aac6e38da2
commit 4fe44339c5
30 changed files with 712 additions and 93 deletions

View File

@@ -170,6 +170,16 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(ForEachStatement& node)
{
auto clone = std::make_unique<ForEachStatement>();
clone->isConst = node.isConst;
clone->expression = CloneExpression(node.expression);
clone->statement = CloneStatement(node.statement);
return clone;
}
StatementPtr AstCloner::Clone(MultiStatement& node)
{
auto clone = std::make_unique<MultiStatement>();

View File

@@ -85,7 +85,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
node.expression->Visit(*this);
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(VariableExpression& /*node*/)
@@ -95,7 +96,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(UnaryExpression& node)
{
node.expression->Visit(*this);
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(BranchStatement& node)
@@ -159,6 +161,15 @@ namespace Nz::ShaderAst
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(ForEachStatement& node)
{
if (node.expression)
node.expression->Visit(*this);
if (node.statement)
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(MultiStatement& node)
{
for (auto& statement : node.statements)

View File

@@ -301,6 +301,14 @@ namespace Nz::ShaderAst
Node(node.expression);
}
void AstSerializerBase::Serialize(ForEachStatement& node)
{
Value(node.isConst);
Value(node.varName);
Node(node.expression);
Node(node.statement);
}
void AstSerializerBase::Serialize(MultiStatement& node)
{
Container(node.statements);

View File

@@ -13,7 +13,7 @@ namespace Nz::ShaderAst
{
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(length);
length = Clone(array.length);
}
ArrayType& ArrayType::operator=(const ArrayType& array)
@@ -21,7 +21,7 @@ namespace Nz::ShaderAst
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(length);
length = Clone(array.length);
return *this;
}

View File

@@ -278,39 +278,7 @@ namespace Nz::ShaderAst
MandatoryExpr(node.right);
auto clone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));
if (GetExpressionCategory(*clone->left) != ExpressionCategory::LValue)
throw AstError{ "Assignation is only possible with a l-value" };
std::optional<BinaryType> binaryType;
switch (clone->op)
{
case AssignType::Simple:
TypeMustMatch(clone->left, clone->right);
break;
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
}
if (binaryType)
{
ExpressionType expressionType = ValidateBinaryOp(*binaryType, clone->left, clone->right);
TypeMustMatch(GetExpressionType(*clone->left), expressionType);
if (m_context->options.removeCompoundAssignments)
{
clone->op = AssignType::Simple;
clone->right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*clone->left), std::move(clone->right));
clone->right->cachedExpressionType = std::move(expressionType);
}
}
clone->cachedExpressionType = GetExpressionType(*clone->left);
Validate(*clone);
return clone;
}
@@ -318,7 +286,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node)
{
auto clone = static_unique_pointer_cast<BinaryExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = ValidateBinaryOp(clone->op, clone->left, clone->right);
Validate(*clone);
return clone;
}
@@ -861,6 +829,119 @@ namespace Nz::ShaderAst
return AstCloner::Clone(node);
}
StatementPtr SanitizeVisitor::Clone(ForEachStatement& node)
{
auto expr = CloneExpression(node.expression);
const ExpressionType& exprType = GetExpressionType(*expr);
ExpressionType innerType;
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
innerType = arrayType.containedType->type;
}
else
throw AstError{ "for-each is only supported on arrays and range expressions" };
if (node.isConst)
{
// Repeat code
auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
UInt32 length = arrayType.length.GetResultingValue();
for (UInt32 i = 0; i < length; ++i)
{
auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
Validate(*accessIndex);
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
Validate(*elementVariable);
multi->statements.emplace_back(std::move(elementVariable));
multi->statements.emplace_back(CloneStatement(node.statement));
}
}
return multi;
}
if (m_context->options.reduceLoopsToWhile)
{
PushScope();
auto multi = std::make_unique<MultiStatement>();
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
UInt32 length = arrayType.length.GetResultingValue();
multi->statements.reserve(2);
// Counter variable
auto counterVariable = ShaderBuilder::DeclareVariable("i", ShaderBuilder::Constant(0u));
Validate(*counterVariable);
std::size_t counterVarIndex = counterVariable->varIndex.value();
multi->statements.emplace_back(std::move(counterVariable));
auto whileStatement = std::make_unique<WhileStatement>();
// While condition
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(length));
Validate(*condition);
whileStatement->condition = std::move(condition);
// While body
auto body = std::make_unique<MultiStatement>();
body->statements.reserve(3);
auto accessIndex = ShaderBuilder::AccessIndex(std::move(expr), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32));
Validate(*accessIndex);
auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
Validate(*elementVariable);
body->statements.emplace_back(std::move(elementVariable));
body->statements.emplace_back(CloneStatement(node.statement));
auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(1u));
Validate(*incrCounter);
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));
whileStatement->body = std::move(body);
multi->statements.emplace_back(std::move(whileStatement));
}
PopScope();
return multi;
}
else
{
auto clone = std::make_unique<ForEachStatement>();
clone->expression = std::move(expr);
clone->varName = node.varName;
PushScope();
{
clone->varIndex = RegisterVariable(node.varName, innerType);
clone->statement = CloneStatement(node.statement);
}
PopScope();
SanitizeIdentifier(node.varName);
return clone;
}
}
StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
{
PushScope();
@@ -1206,7 +1287,6 @@ namespace Nz::ShaderAst
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, ArrayType> ||
std::is_same_v<T, PrimitiveType> ||
std::is_same_v<T, MatrixType> ||
std::is_same_v<T, SamplerType> ||
@@ -1215,6 +1295,22 @@ namespace Nz::ShaderAst
{
return exprType;
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
ArrayType resolvedArrayType;
if (arg.length.IsExpression())
{
resolvedArrayType.length = CloneExpression(arg.length.GetExpression());
ComputeAttributeValue(resolvedArrayType.length);
}
else if (arg.length.IsResultingValue())
resolvedArrayType.length = arg.length.GetResultingValue();
resolvedArrayType.containedType = std::make_unique<ContainedType>();
resolvedArrayType.containedType->type = ResolveType(arg.containedType->type);
return resolvedArrayType;
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
const Identifier* identifier = FindIdentifier(arg.name);
@@ -1262,8 +1358,12 @@ namespace Nz::ShaderAst
for (auto& index : node.indices)
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType) || std::get<PrimitiveType>(indexType) != PrimitiveType::Int32)
throw AstError{ "AccessIndex expects Int32 indices" };
if (!IsPrimitiveType(indexType))
throw AstError{ "AccessIndex expects integer indices" };
PrimitiveType primitiveIndexType = std::get<PrimitiveType>(indexType);
if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32)
throw AstError{ "AccessIndex expects integer indices" };
}
ExpressionType exprType = GetExpressionType(*node.expr);
@@ -1272,8 +1372,8 @@ namespace Nz::ShaderAst
if (IsArrayType(exprType))
{
const ArrayType& arrayType = std::get<ArrayType>(exprType);
exprType = arrayType.containedType->type;
ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType
exprType = std::move(containedType);
}
else if (IsStructType(exprType))
{
@@ -1294,7 +1394,7 @@ namespace Nz::ShaderAst
else if (IsMatrixType(exprType))
{
// Matrix index (ex: mat[2])
const MatrixType& matrixType = std::get<MatrixType>(exprType);
MatrixType matrixType = std::get<MatrixType>(exprType);
//TODO: Handle row-major matrices
exprType = VectorType{ matrixType.rowCount, matrixType.type };
@@ -1302,7 +1402,7 @@ namespace Nz::ShaderAst
else if (IsVectorType(exprType))
{
// Swizzle expression with one component (ex: vec[2])
const VectorType& swizzledVec = std::get<VectorType>(exprType);
VectorType swizzledVec = std::get<VectorType>(exprType);
exprType = swizzledVec.type;
}
@@ -1313,6 +1413,47 @@ namespace Nz::ShaderAst
node.cachedExpressionType = std::move(exprType);
}
void SanitizeVisitor::Validate(AssignExpression& node)
{
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError{ "Assignation is only possible with a l-value" };
std::optional<BinaryType> binaryType;
switch (node.op)
{
case AssignType::Simple:
TypeMustMatch(node.left, node.right);
break;
case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
}
if (binaryType)
{
ExpressionType expressionType = ValidateBinaryOp(*binaryType, node.left, node.right);
TypeMustMatch(GetExpressionType(*node.left), expressionType);
if (m_context->options.removeCompoundAssignments)
{
node.op = AssignType::Simple;
node.right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*node.left), std::move(node.right));
node.right->cachedExpressionType = std::move(expressionType);
}
}
node.cachedExpressionType = GetExpressionType(*node.left);
}
void SanitizeVisitor::Validate(BinaryExpression& node)
{
node.cachedExpressionType = ValidateBinaryOp(node.op, node.left, node.right);
}
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
{
if (referenceDeclaration->entryStage.HasValue())

View File

@@ -207,28 +207,23 @@ namespace Nz
ShaderAst::SanitizeVisitor::Options options;
options.optionValues = std::move(optionValues);
options.makeVariableNameUnique = true;
options.reduceLoopsToWhile = true;
options.removeCompoundAssignments = false;
options.removeOptionDeclaration = true;
options.removeScalarSwizzling = true;
options.reservedIdentifiers = {
// All reserved GLSL keywords as of GLSL ES 3.2
"active", "asm", "atomic_uint", "attribute", "bool", "break", "buffer", "bvec2", "bvec3", "bvec4", "case", "cast", "centroid", "class", "coherent", "common", "const", "continue", "default", "discard", "dmat2", "dmat2x2", "dmat2x3", "dmat2x4", "dmat3", "dmat3x2", "dmat3x3", "dmat3x4", "dmat4", "dmat4x2", "dmat4x3", "dmat4x4", "do", "double", "dvec2", "dvec3", "dvec4", "else", "enum", "extern", "external", "false", "filter", "fixed", "flat", "float", "for", "fvec2", "fvec3", "fvec4", "goto", "half", "highp", "hvec2", "hvec3", "hvec4", "if", "iimage1D", "iimage1DArray", "iimage2D", "iimage2DArray", "iimage2DMS", "iimage2DMSArray", "iimage2DRect", "iimage3D", "iimageBuffer", "iimageCube", "iimageCubeArray", "image1D", "image1DArray", "image2D", "image2DArray", "image2DMS", "image2DMSArray", "image2DRect", "image3D", "imageBuffer", "imageCube", "imageCubeArray", "in", "inline", "inout", "input", "int", "interface", "invariant", "isampler1D", "isampler1DArray", "isampler2D", "isampler2DArray", "isampler2DMS", "isampler2DMSArray", "isampler2DRect", "isampler3D", "isamplerBuffer", "isamplerCube", "isamplerCubeArray", "isubpassInput", "isubpassInputMS", "itexture2D", "itexture2DArray", "itexture2DMS", "itexture2DMSArray", "itexture3D", "itextureBuffer", "itextureCube", "itextureCubeArray", "ivec2", "ivec3", "ivec4", "layout", "long", "lowp", "mat2", "mat2x2", "mat2x3", "mat2x4", "mat3", "mat3x2", "mat3x3", "mat3x4", "mat4", "mat4x2", "mat4x3", "mat4x4", "mediump", "namespace", "noinline", "noperspective", "out", "output", "partition", "patch", "precise", "precision", "public", "readonly", "resource", "restrict", "return", "sample", "sampler", "sampler1D", "sampler1DArray", "sampler1DArrayShadow", "sampler1DShadow", "sampler2D", "sampler2DArray", "sampler2DArrayShadow", "sampler2DMS", "sampler2DMSArray", "sampler2DRect", "sampler2DRectShadow", "sampler2DShadow", "sampler3D", "sampler3DRect", "samplerBuffer", "samplerCube", "samplerCubeArray", "samplerCubeArrayShadow", "samplerCubeShadow", "samplerShadow", "shared", "short", "sizeof", "smooth", "static", "struct", "subpassInput", "subpassInputMS", "subroutine", "superp", "switch", "template", "texture2D", "texture2DArray", "texture2DMS", "texture2DMSArray", "texture3D", "textureBuffer", "textureCube", "textureCubeArray", "this", "true", "typedef", "uimage1D", "uimage1DArray", "uimage2D", "uimage2DArray", "uimage2DMS", "uimage2DMSArray", "uimage2DRect", "uimage3D", "uimageBuffer", "uimageCube", "uimageCubeArray", "uint", "uniform", "union", "unsigned", "usampler1D", "usampler1DArray", "usampler2D", "usampler2DArray", "usampler2DMS", "usampler2DMSArray", "usampler2DRect", "usampler3D", "usamplerBuffer", "usamplerCube", "usamplerCubeArray", "using", "usubpassInput", "usubpassInputMS", "utexture2D", "utexture2DArray", "utexture2DMS", "utexture2DMSArray", "utexture3D", "utextureBuffer", "utextureCube", "utextureCubeArray", "uvec2", "uvec3", "uvec4", "varying", "vec2", "vec3", "vec4", "void", "volatile", "while", "writeonly"
// GLSL functions
"cross", "dot", "length", "max", "min", "pow", "texture"
"cross", "dot", "exp", "length", "max", "min", "pow", "texture"
};
return ShaderAst::Sanitize(ast, options, error);
}
void GlslWriter::Append(const ShaderAst::ArrayType& type)
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
{
Append(type.containedType->type, "[");
if (type.length.IsResultingValue())
Append(type.length.GetResultingValue());
else
type.length.GetExpression()->Visit(*this);
Append("]");
throw std::runtime_error("unexpected ArrayType");
}
void GlslWriter::Append(const ShaderAst::ExpressionType& type)
@@ -390,7 +385,7 @@ namespace Nz
first = false;
Append(parameter.type, " ", parameter.name);
AppendVariableDeclaration(parameter.type, parameter.name);
}
AppendLine((forward) ? ");" : ")");
}
@@ -538,6 +533,40 @@ namespace Nz
}
}
void GlslWriter::AppendVariableDeclaration(const ShaderAst::ExpressionType& varType, const std::string& varName)
{
if (ShaderAst::IsArrayType(varType))
{
std::vector<const ShaderAst::AttributeValue<UInt32>*> lengths;
const ShaderAst::ExpressionType* exprType = &varType;
while (ShaderAst::IsArrayType(*exprType))
{
const auto& arrayType = std::get<ShaderAst::ArrayType>(*exprType);
lengths.push_back(&arrayType.length);
exprType = &arrayType.containedType->type;
}
assert(!ShaderAst::IsArrayType(*exprType));
Append(*exprType, " ", varName);
for (const auto* lengthAttribute : lengths)
{
Append("[");
if (lengthAttribute->IsResultingValue())
Append(lengthAttribute->GetResultingValue());
else
lengthAttribute->GetExpression()->Visit(*this);
Append("]");
}
}
else
Append(varType, " ", varName);
}
void GlslWriter::EnterScope()
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@@ -632,13 +661,8 @@ namespace Nz
{
Append("layout(location = ");
Append(member.locationIndex.GetResultingValue());
Append(") ");
Append(keyword);
Append(" ");
Append(member.type);
Append(" ");
Append(targetPrefix);
Append(member.name);
Append(") ", keyword, " ");
AppendVariableDeclaration(member.type, targetPrefix + member.name);
AppendLine(";");
fields.push_back({
@@ -824,8 +848,10 @@ namespace Nz
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, bool>)
Append((arg) ? "true" : "false");
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32>)
Append(std::to_string(arg));
else if constexpr (std::is_same_v<T, UInt32>)
Append(std::to_string(arg), "u");
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i32>)
Append("vec2(" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ")");
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i32>)
@@ -1033,19 +1059,18 @@ namespace Nz
first = false;
Append(member.type);
Append(" ");
Append(member.name);
AppendVariableDeclaration(member.type, member.name);
Append(";");
}
}
LeaveScope(false);
Append(" ");
Append(externalVar.name);
}
else
Append(externalVar.type);
AppendVariableDeclaration(externalVar.type, externalVar.name);
Append(" ");
Append(externalVar.name);
AppendLine(";");
if (IsUniformType(externalVar.type))
@@ -1127,9 +1152,7 @@ namespace Nz
first = false;
Append(member.type);
Append(" ");
Append(member.name);
AppendVariableDeclaration(member.type, member.name);
Append(";");
}
}
@@ -1142,7 +1165,7 @@ namespace Nz
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
Append(node.varType, " ", node.varName);
AppendVariableDeclaration(node.varType, node.varName);
if (node.initialExpression)
{
Append(" = ");

View File

@@ -170,7 +170,7 @@ namespace Nz
case ShaderAst::PrimitiveType::Boolean: return Append("bool");
case ShaderAst::PrimitiveType::Float32: return Append("f32");
case ShaderAst::PrimitiveType::Int32: return Append("i32");
case ShaderAst::PrimitiveType::UInt32: return Append("ui32");
case ShaderAst::PrimitiveType::UInt32: return Append("u32");
}
}
@@ -185,7 +185,7 @@ namespace Nz
case ImageType::E2D: Append("2D"); break;
case ImageType::E2D_Array: Append("2DArray"); break;
case ImageType::E3D: Append("3D"); break;
case ImageType::Cubemap: Append("Cube"); break;
case ImageType::Cubemap: Append("Cube"); break;
}
Append("<", samplerType.sampledType, ">");
@@ -653,6 +653,21 @@ namespace Nz
node.statement->Visit(*this);
}
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
{
assert(node.constIndex);
RegisterConstant(*node.constIndex, node.name);
Append("const ", node.name, ": ", node.type);
if (node.expression)
{
Append(" = ");
node.expression->Visit(*this);
}
Append(";");
}
void LangWriter::Visit(ShaderAst::ConstantValueExpression& node)
{
std::visit([&](auto&& arg)
@@ -811,6 +826,20 @@ namespace Nz
Append(";");
}
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
Append("for ", node.varName, " in ");
node.expression->Visit(*this);
AppendLine();
EnterScope();
node.statement->Visit(*this);
LeaveScope();
}
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)
{
bool method = false;

View File

@@ -47,7 +47,9 @@ namespace Nz::ShaderLang
{ "external", TokenType::External },
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "for", TokenType::For },
{ "if", TokenType::If },
{ "in", TokenType::In },
{ "let", TokenType::Let },
{ "option", TokenType::Option },
{ "return", TokenType::Return },

View File

@@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Core/File.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <cassert>
@@ -472,6 +473,14 @@ namespace Nz::ShaderLang
switch (Peek().type)
{
case TokenType::For:
{
auto forEach = ParseForDeclaration();
SafeCast<ShaderAst::ForEachStatement&>(*forEach).isConst = true;
return forEach;
}
case TokenType::Identifier:
{
std::string constName;
@@ -487,7 +496,7 @@ namespace Nz::ShaderLang
case TokenType::If:
{
auto branch = ParseBranchStatement();
static_cast<ShaderAst::BranchStatement&>(*branch).isConst = true;
SafeCast<ShaderAst::BranchStatement&>(*branch).isConst = true;
return branch;
}
@@ -589,6 +598,21 @@ namespace Nz::ShaderLang
return externalStatement;
}
ShaderAst::StatementPtr Parser::ParseForDeclaration()
{
Expect(Advance(), TokenType::For);
std::string varName = ParseIdentifierAsName();
Expect(Advance(), TokenType::In);
ShaderAst::ExpressionPtr expr = ParseExpression();
ShaderAst::StatementPtr statement = ParseStatement();
return ShaderBuilder::ForEach(std::move(varName), std::move(expr), std::move(statement));
}
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
{
return ParseStatementList();
@@ -734,6 +758,10 @@ namespace Nz::ShaderLang
statement = ParseDiscardStatement();
break;
case TokenType::For:
statement = ParseForDeclaration();
break;
case TokenType::Let:
statement = ParseVariableDeclaration();
break;

View File

@@ -38,7 +38,7 @@ namespace Nz
bool Compare(const Array& lhs, const Array& rhs) const
{
return lhs.length == rhs.length && Compare(lhs.elementType, rhs.elementType);
return Compare(lhs.length, rhs.length) && Compare(lhs.elementType, rhs.elementType) && lhs.stride == rhs.stride;
}
bool Compare(const Bool& /*lhs*/, const Bool& /*rhs*/) const
@@ -237,6 +237,8 @@ namespace Nz
{
assert(array.elementType);
cache.Register(*array.elementType);
assert(array.length);
cache.Register(*array.length);
}
void Register(const Bool&) {}
@@ -416,6 +418,7 @@ namespace Nz
tsl::ordered_map<Structure, FieldOffsets /*fieldOffsets*/, AnyHasher, Eq> structureSizes;
StructCallback structCallback;
UInt32& nextResultId;
bool isInBlockStruct = false;
};
SpirvConstantCache::SpirvConstantCache(UInt32& resultId)
@@ -493,26 +496,59 @@ namespace Nz
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr
{
return std::make_shared<Type>(Pointer{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const -> TypePtr
{
return std::make_shared<Type>(Pointer{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
type,
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildType(const ShaderAst::ArrayType& type) const -> TypePtr
{
return std::make_shared<Type>(Array{
BuildType(type.containedType->type),
BuildConstant(type.length.GetResultingValue()),
(m_internal->isInBlockStruct) ? std::make_optional<UInt32>(16) : std::nullopt
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
{
return std::make_shared<Type>(Pointer{
bool wasInblockStruct = m_internal->isInBlockStruct;
if (storageClass == SpirvStorageClass::Uniform)
m_internal->isInBlockStruct = true;
auto typePtr = std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
m_internal->isInBlockStruct = wasInblockStruct;
return typePtr;
}
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
@@ -614,6 +650,10 @@ namespace Nz
sType.name = structDesc.name;
sType.decorations = std::move(decorations);
bool wasInBlock = m_internal->isInBlockStruct;
if (!wasInBlock)
m_internal->isInBlockStruct = std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::Block) != sType.decorations.end();
for (const auto& member : structDesc.members)
{
if (member.cond.HasValue() && !member.cond.GetResultingValue())
@@ -624,6 +664,8 @@ namespace Nz
sMembers.type = BuildType(member.type);
}
m_internal->isInBlockStruct = wasInBlock;
return std::make_shared<Type>(std::move(sType));
}
@@ -814,7 +856,11 @@ namespace Nz
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, Array>)
constants.Append(SpirvOp::OpTypeArray, resultId, GetId(*arg.elementType), arg.length);
{
constants.Append(SpirvOp::OpTypeArray, resultId, GetId(*arg.elementType), GetId(*arg.length));
if (arg.stride)
annotations.Append(SpirvOp::OpDecorate, resultId, SpirvDecoration::ArrayStride, *arg.stride);
}
else if constexpr (std::is_same_v<T, Bool>)
constants.Append(SpirvOp::OpTypeBool, resultId);
else if constexpr (std::is_same_v<T, Float>)
@@ -908,8 +954,23 @@ namespace Nz
if constexpr (std::is_same_v<T, Array>)
{
// TODO
throw std::runtime_error("todo");
assert(std::holds_alternative<ConstantScalar>(arg.length->constant));
const auto& scalar = std::get<ConstantScalar>(arg.length->constant);
assert(std::holds_alternative<UInt32>(scalar.value));
std::size_t length = std::get<UInt32>(scalar.value);
if (!std::holds_alternative<Float>(arg.elementType->type))
throw std::runtime_error("todo");
// FIXME: Virer cette implémentation du ghetto
const Float& fData = std::get<Float>(arg.elementType->type);
switch (fData.width)
{
case 32: return structOffsets.AddFieldArray(StructFieldType::Float1, length);
case 64: return structOffsets.AddFieldArray(StructFieldType::Double1, length);
default: throw std::runtime_error("unexpected float width " + std::to_string(fData.width));
}
}
else if constexpr (std::is_same_v<T, Bool>)
return structOffsets.AddField(StructFieldType::Bool1);

View File

@@ -487,7 +487,9 @@ namespace Nz
{
ShaderAst::SanitizeVisitor::Options options;
options.optionValues = states.optionValues;
options.reduceLoopsToWhile = true;
options.removeCompoundAssignments = true;
options.removeOptionDeclaration = true;
options.splitMultipleBranches = true;
sanitizedAst = ShaderAst::Sanitize(shader, options);