Shader/NZSL: Add support for array indexing

This commit is contained in:
Jérôme Leclercq 2021-06-01 16:22:41 +02:00
parent 0f3c0abb96
commit 4465e230af
17 changed files with 1139 additions and 976 deletions

View File

@ -70,7 +70,7 @@ struct VertOut
[entry(frag)]
fn main(input: FragIn) -> FragOut
{
let fragcoord = (input.fragcoord).xy * viewerData.invRenderTargetSize;
let fragcoord = input.fragcoord.xy * viewerData.invRenderTargetSize;
let normal = normalTexture.Sample(fragcoord).xyz * 2.0 - vec3<f32>(1.0, 1.0, 1.0);
let position = positionTexture.Sample(fragcoord).xyz;

View File

@ -70,7 +70,7 @@ namespace Nz::ShaderAst
void Visit(AstExpressionVisitor& visitor) override;
ExpressionPtr expr;
std::vector<std::string> memberIdentifiers;
std::vector<std::string> identifiers;
};
struct NAZARA_SHADER_API AccessIndexExpression : public Expression
@ -79,7 +79,7 @@ namespace Nz::ShaderAst
void Visit(AstExpressionVisitor& visitor) override;
ExpressionPtr expr;
std::vector<std::size_t> memberIndices;
std::vector<ExpressionPtr> indices;
};
struct NAZARA_SHADER_API AssignExpression : public Expression

View File

@ -43,11 +43,10 @@ namespace Nz::ShaderAst
struct FunctionData;
struct Identifier;
const ExpressionType& CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices);
using AstCloner::CloneExpression;
ExpressionPtr Clone(AccessIdentifierExpression& node) override;
ExpressionPtr Clone(AccessIndexExpression& node) override;
ExpressionPtr Clone(AssignExpression& node) override;
ExpressionPtr Clone(BinaryExpression& node) override;
ExpressionPtr Clone(CallFunctionExpression& node) override;
@ -101,6 +100,7 @@ namespace Nz::ShaderAst
void SanitizeIdentifier(std::string& identifier);
void Validate(AccessIndexExpression& node);
void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration);
void Validate(IntrinsicExpression& node);

View File

@ -62,7 +62,7 @@ namespace Nz
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
void AppendCommentSection(const std::string& section);
void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false);
void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers);
void AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers);
void AppendHeader();
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);

View File

@ -65,7 +65,7 @@ namespace Nz
void AppendAttribute(LayoutAttribute layout);
void AppendAttribute(LocationAttribute location);
void AppendCommentSection(const std::string& section);
void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers);
void AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers);
void AppendHeader();
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);

View File

@ -19,7 +19,8 @@ namespace Nz::ShaderBuilder
{
struct AccessIndex
{
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, std::vector<std::size_t> memberIndices) const;
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, const std::vector<Int32>& indexConstants) const;
inline std::unique_ptr<ShaderAst::AccessIndexExpression> operator()(ShaderAst::ExpressionPtr expr, std::vector<ShaderAst::ExpressionPtr> indexExpressions) const;
};
struct AccessMember

View File

@ -11,16 +11,28 @@ namespace Nz::ShaderBuilder
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIdentifierExpression>();
accessMemberNode->expr = std::move(expr);
accessMemberNode->memberIdentifiers = std::move(memberIdentifiers);
accessMemberNode->identifiers = std::move(memberIdentifiers);
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, std::vector<std::size_t> memberIndices) const
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, const std::vector<Int32>& indexConstants) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIndexExpression>();
accessMemberNode->expr = std::move(expr);
accessMemberNode->memberIndices = std::move(memberIndices);
accessMemberNode->indices.reserve(indexConstants.size());
for (Int32 index : indexConstants)
accessMemberNode->indices.push_back(ShaderBuilder::Constant(index));
return accessMemberNode;
}
inline std::unique_ptr<ShaderAst::AccessIndexExpression> Impl::AccessIndex::operator()(ShaderAst::ExpressionPtr expr, std::vector<ShaderAst::ExpressionPtr> indexExpressions) const
{
auto accessMemberNode = std::make_unique<ShaderAst::AccessIndexExpression>();
accessMemberNode->expr = std::move(expr);
accessMemberNode->indices = std::move(indexExpressions);
return accessMemberNode;
}

View File

@ -161,7 +161,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(AccessIdentifierExpression& node)
{
auto clone = std::make_unique<AccessIdentifierExpression>();
clone->memberIdentifiers = node.memberIdentifiers;
clone->identifiers = node.identifiers;
clone->expr = CloneExpression(node.expr);
clone->cachedExpressionType = node.cachedExpressionType;
@ -172,9 +172,12 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(AccessIndexExpression& node)
{
auto clone = std::make_unique<AccessIndexExpression>();
clone->memberIndices = node.memberIndices;
clone->expr = CloneExpression(node.expr);
clone->indices.reserve(node.indices.size());
for (auto& parameter : node.indices)
clone->indices.push_back(CloneExpression(parameter));
clone->cachedExpressionType = node.cachedExpressionType;
return clone;

View File

@ -15,6 +15,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(AccessIndexExpression& node)
{
node.expr->Visit(*this);
for (auto& index : node.indices)
index->Visit(*this);
}
void AstRecursiveVisitor::Visit(AssignExpression& node)

View File

@ -37,8 +37,8 @@ namespace Nz::ShaderAst
{
Node(node.expr);
Container(node.memberIdentifiers);
for (std::string& identifier : node.memberIdentifiers)
Container(node.identifiers);
for (std::string& identifier : node.identifiers)
Value(identifier);
}
@ -46,9 +46,9 @@ namespace Nz::ShaderAst
{
Node(node.expr);
Container(node.memberIndices);
for (std::size_t& identifier : node.memberIndices)
SizeT(identifier);
Container(node.indices);
for (auto& identifier : node.indices)
Node(identifier);
}
void AstSerializerBase::Serialize(AssignExpression& node)

View File

@ -95,110 +95,114 @@ namespace Nz::ShaderAst
return clone;
}
const ExpressionType& SanitizeVisitor::CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices)
{
std::size_t structIndex = ResolveStruct(structType);
*structIndices++ = structIndex;
assert(structIndex < m_structs.size());
const StructDescription& s = m_structs[structIndex];
auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
if (memberIt == s.members.end())
throw AstError{ "unknown field " + memberIdentifier[0] };
const auto& member = *memberIt;
if (remainingMembers > 1)
return CheckField(member.type, memberIdentifier + 1, remainingMembers - 1, structIndices);
else
return member.type;
}
ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node)
{
auto structExpr = CloneExpression(MandatoryExpr(node.expr));
if (node.identifiers.empty())
throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
const ExpressionType& exprType = GetExpressionType(*structExpr);
if (IsVectorType(exprType))
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr));
for (std::size_t i = 0; i < node.identifiers.size(); ++i)
{
const VectorType& swizzledVec = std::get<VectorType>(exprType);
const std::string& identifier = node.identifiers[i];
// Swizzle expression
auto swizzle = std::make_unique<SwizzleExpression>();
swizzle->expression = std::move(structExpr);
// FIXME: Handle properly multiple identifiers (treat recursively)
if (node.memberIdentifiers.size() != 1)
throw AstError{ "invalid swizzle" };
const std::string& swizzleStr = node.memberIdentifiers.front();
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
throw AstError{ "invalid swizzle" };
swizzle->componentCount = swizzleStr.size();
if (swizzle->componentCount > 1)
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
else
swizzle->cachedExpressionType = swizzledVec.type;
for (std::size_t i = 0; i < swizzle->componentCount; ++i)
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
if (IsStructType(exprType))
{
switch (swizzleStr[i])
// Transform to AccessIndexExpression
AccessIndexExpression* accessIndexPtr;
if (indexedExpr->GetType() != NodeType::AccessIndexExpression)
{
case 'r':
case 'x':
case 's':
swizzle->components[i] = SwizzleComponent::First;
break;
std::unique_ptr<AccessIndexExpression> accessIndex = std::make_unique<AccessIndexExpression>();
accessIndex->expr = std::move(indexedExpr);
case 'g':
case 'y':
case 't':
swizzle->components[i] = SwizzleComponent::Second;
break;
case 'b':
case 'z':
case 'p':
swizzle->components[i] = SwizzleComponent::Third;
break;
case 'a':
case 'w':
case 'q':
swizzle->components[i] = SwizzleComponent::Fourth;
break;
accessIndexPtr = accessIndex.get();
indexedExpr = std::move(accessIndex);
}
else
accessIndexPtr = static_cast<AccessIndexExpression*>(indexedExpr.get());
std::size_t structIndex = ResolveStruct(exprType);
assert(structIndex < m_structs.size());
const StructDescription& s = m_structs[structIndex];
auto it = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == identifier; });
if (it == s.members.end())
throw AstError{ "unknown field " + identifier };
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s.members.begin(), it))));
accessIndexPtr->cachedExpressionType = ResolveType(it->type);
}
else if (IsVectorType(exprType))
{
// Swizzle expression
const VectorType& swizzledVec = std::get<VectorType>(exprType);
return swizzle;
auto swizzle = std::make_unique<SwizzleExpression>();
swizzle->expression = std::move(indexedExpr);
if (node.identifiers.size() - i != 1)
throw AstError{ "invalid swizzle" };
const std::string& swizzleStr = node.identifiers[i];
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
throw AstError{ "invalid swizzle" };
swizzle->componentCount = swizzleStr.size();
if (swizzle->componentCount > 1)
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
else
swizzle->cachedExpressionType = swizzledVec.type;
for (std::size_t j = 0; j < swizzle->componentCount; ++j)
{
switch (swizzleStr[j])
{
case 'r':
case 'x':
case 's':
swizzle->components[j] = SwizzleComponent::First;
break;
case 'g':
case 'y':
case 't':
swizzle->components[j] = SwizzleComponent::Second;
break;
case 'b':
case 'z':
case 'p':
swizzle->components[j] = SwizzleComponent::Third;
break;
case 'a':
case 'w':
case 'q':
swizzle->components[j] = SwizzleComponent::Fourth;
break;
}
}
indexedExpr = std::move(swizzle);
}
else
throw AstError{ "unexpected type (only struct and vectors can be indexed with identifiers)" }; //< TODO: Add support for arrays
}
// Transform to AccessIndexExpression
auto accessMemberIndex = std::make_unique<AccessIndexExpression>();
accessMemberIndex->expr = std::move(structExpr);
return indexedExpr;
}
StackArray<std::size_t> structIndices = NazaraStackArrayNoInit(std::size_t, node.memberIdentifiers.size());
ExpressionPtr SanitizeVisitor::Clone(AccessIndexExpression& node)
{
MandatoryExpr(node.expr);
for (auto& index : node.indices)
MandatoryExpr(index);
accessMemberIndex->cachedExpressionType = ResolveType(CheckField(exprType, node.memberIdentifiers.data(), node.memberIdentifiers.size(), structIndices.data()));
auto clone = static_unique_pointer_cast<AccessIndexExpression>(AstCloner::Clone(node));
Validate(*clone);
accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size());
for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i)
{
std::size_t structIndex = structIndices[i];
assert(structIndex < m_structs.size());
const StructDescription& structDesc = m_structs[structIndex];
auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; });
assert(it != structDesc.members.end());
accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it);
}
return accessMemberIndex;
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
@ -1204,6 +1208,61 @@ namespace Nz::ShaderAst
}
}
void SanitizeVisitor::Validate(AccessIndexExpression& node)
{
if (node.indices.empty())
throw AstError{ "AccessIndexExpression must have at least one index" };
for (auto& index : node.indices)
{
const ShaderAst::ExpressionType& indexType = GetExpressionType(*index);
if (!IsPrimitiveType(indexType) || std::get<PrimitiveType>(indexType) != PrimitiveType::Int32)
throw AstError{ "AccessIndex expects Int32 indices" };
}
ExpressionType exprType = GetExpressionType(*node.expr);
for (std::size_t i = 0; i < node.indices.size(); ++i)
{
if (IsStructType(exprType))
{
auto& indexExpr = node.indices[i];
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
if (indexExpr->GetType() != NodeType::ConstantExpression)
throw AstError{ "struct can only be accessed with constant indices" };
ConstantExpression& constantExpr = static_cast<ConstantExpression&>(*indexExpr);
Int32 index = std::get<Int32>(constantExpr.value);
std::size_t structIndex = ResolveStruct(exprType);
assert(structIndex < m_structs.size());
const StructDescription& s = m_structs[structIndex];
exprType = ResolveType(s.members[index].type);
}
else if (IsMatrixType(exprType))
{
// Matrix index (ex: mat[2])
const MatrixType& matrixType = std::get<MatrixType>(exprType);
//TODO: Handle row-major matrices
exprType = VectorType{ matrixType.rowCount, matrixType.type };
}
else if (IsVectorType(exprType))
{
// Swizzle expression with one component (ex: vec[2])
const VectorType& swizzledVec = std::get<VectorType>(exprType);
exprType = swizzledVec.type;
}
else
throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays
}
node.cachedExpressionType = std::move(exprType);
}
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
{
if (referenceDeclaration->entryStage)

View File

@ -355,11 +355,15 @@ namespace Nz
AppendLine((forward) ? ");" : ")");
}
void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
void GlslWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers)
{
const auto& structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& member = structDesc.members[*memberIndices];
assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression);
auto& constantValue = static_cast<ShaderAst::ConstantExpression&>(**memberIndices);
Int32 index = std::get<Int32>(constantValue.value);
const auto& member = structDesc.members[index];
Append(".");
Append(member.name);
@ -655,9 +659,20 @@ namespace Nz
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType));
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.memberIndices.data(), node.memberIndices.size());
// For structs, convert indices to field names
if (IsStructType(exprType))
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.indices.data(), node.indices.size());
else
{
// Array access
for (ShaderAst::ExpressionPtr& expr : node.indices)
{
Append("[");
Visit(expr);
Append("]");
}
}
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)

View File

@ -342,11 +342,15 @@ namespace Nz
AppendLine();
}
void LangWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
void LangWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers)
{
const auto& structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& member = structDesc.members[*memberIndices];
assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression);
auto& constantValue = static_cast<ShaderAst::ConstantExpression&>(**memberIndices);
Int32 index = std::get<Int32>(constantValue.value);
const auto& member = structDesc.members[index];
Append(".");
Append(member.name);
@ -443,9 +447,20 @@ namespace Nz
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType));
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.memberIndices.data(), node.memberIndices.size());
// For structs, convert indices to field names
if (IsStructType(exprType))
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.indices.data(), node.indices.size());
else
{
// Array access
for (ShaderAst::ExpressionPtr& expr : node.indices)
{
Append("[");
Visit(expr);
Append("]");
}
}
}
void LangWriter::Visit(ShaderAst::AssignExpression& node)

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
@ -55,28 +56,36 @@ namespace Nz
{
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
StackArray<UInt32> indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size());
for (std::size_t i = 0; i < node.indices.size(); ++i)
indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]);
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.pointerId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
for (UInt32 id : indexIds)
appender(id);
});
m_value = Pointer { pointer.storage, resultId, typeId };
},
[&](const Value& value)
{
StackArray<UInt32> indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size());
for (std::size_t i = 0; i < node.indices.size(); ++i)
indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]);
m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
appender(typeId);
appender(resultId);
appender(value.resultId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
for (UInt32 id : indexIds)
appender(id);
});
m_value = Value { resultId };

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
@ -50,14 +51,18 @@ namespace Nz
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
StackArray<UInt32> indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size());
for (std::size_t i = 0; i < node.indices.size(); ++i)
indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]);
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.pointerId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
for (UInt32 id : indexIds)
appender(id);
});
m_value = Pointer { pointer.storage, resultId };

View File

@ -75,9 +75,6 @@ namespace Nz
{
AstRecursiveVisitor::Visit(node);
for (std::size_t index : node.memberIndices)
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(index)));
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}