Shader: Rework Swizzle and add support for swizzle store in SPIRV
This commit is contained in:
parent
e43a638112
commit
837b72f68e
|
|
@ -128,14 +128,6 @@ namespace Nz
|
|||
UInt32, //< ui32
|
||||
};
|
||||
|
||||
enum class SwizzleComponent
|
||||
{
|
||||
First,
|
||||
Second,
|
||||
Third,
|
||||
Fourth
|
||||
};
|
||||
|
||||
enum class UnaryType
|
||||
{
|
||||
LogicalNot, //< !v
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ namespace Nz::ShaderAst
|
|||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
std::array<SwizzleComponent, 4> components;
|
||||
std::array<UInt32, 4> components;
|
||||
std::size_t componentCount;
|
||||
ExpressionPtr expression;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ namespace Nz::ShaderBuilder
|
|||
|
||||
struct Swizzle
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<ShaderAst::SwizzleComponent> swizzleComponents) const;
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const;
|
||||
};
|
||||
|
||||
struct Unary
|
||||
|
|
|
|||
|
|
@ -277,7 +277,7 @@ namespace Nz::ShaderBuilder
|
|||
return returnNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<ShaderAst::SwizzleComponent> swizzleComponents) const
|
||||
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const
|
||||
{
|
||||
auto swizzleNode = std::make_unique<ShaderAst::SwizzleExpression>();
|
||||
swizzleNode->expression = std::move(expression);
|
||||
|
|
@ -285,7 +285,10 @@ namespace Nz::ShaderBuilder
|
|||
assert(swizzleComponents.size() <= swizzleNode->components.size());
|
||||
swizzleNode->componentCount = swizzleComponents.size();
|
||||
for (std::size_t i = 0; i < swizzleNode->componentCount; ++i)
|
||||
{
|
||||
assert(swizzleComponents[i] >= 0 && swizzleComponents[i] <= 4);
|
||||
swizzleNode->components[i] = swizzleComponents[i];
|
||||
}
|
||||
|
||||
return swizzleNode;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/SpirvData.hpp>
|
||||
#include <Nazara/Shader/Ast/AstExpressionVisitorExcept.hpp>
|
||||
#include <Nazara/Shader/Ast/Enums.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
|
|
@ -37,21 +38,23 @@ namespace Nz
|
|||
SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete;
|
||||
|
||||
private:
|
||||
struct LocalVar
|
||||
{
|
||||
std::string varName;
|
||||
};
|
||||
|
||||
struct Pointer
|
||||
{
|
||||
SpirvStorageClass storage;
|
||||
UInt32 pointerId;
|
||||
};
|
||||
|
||||
struct SwizzledPointer : Pointer
|
||||
{
|
||||
ShaderAst::VectorType swizzledType;
|
||||
std::array<UInt32, 4> swizzleIndices;
|
||||
std::size_t componentCount;
|
||||
};
|
||||
|
||||
SpirvAstVisitor& m_visitor;
|
||||
SpirvBlock& m_block;
|
||||
SpirvWriter& m_writer;
|
||||
std::variant<std::monostate, LocalVar, Pointer> m_value;
|
||||
std::variant<std::monostate, Pointer, SwizzledPointer> m_value;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -200,26 +200,29 @@ namespace Nz::ShaderAst
|
|||
case 'r':
|
||||
case 'x':
|
||||
case 's':
|
||||
swizzle->components[j] = SwizzleComponent::First;
|
||||
swizzle->components[j] = 0u;
|
||||
break;
|
||||
|
||||
case 'g':
|
||||
case 'y':
|
||||
case 't':
|
||||
swizzle->components[j] = SwizzleComponent::Second;
|
||||
swizzle->components[j] = 1u;
|
||||
break;
|
||||
|
||||
case 'b':
|
||||
case 'z':
|
||||
case 'p':
|
||||
swizzle->components[j] = SwizzleComponent::Third;
|
||||
swizzle->components[j] = 2u;
|
||||
break;
|
||||
|
||||
case 'a':
|
||||
case 'w':
|
||||
case 'q':
|
||||
swizzle->components[j] = SwizzleComponent::Fourth;
|
||||
swizzle->components[j] = 3u;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw AstError{ "unexpected character '" + std::string(swizzleStr) + "' on swizzle " };
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -303,8 +306,8 @@ namespace Nz::ShaderAst
|
|||
auto clone = std::make_unique<CallFunctionExpression>();
|
||||
|
||||
clone->parameters.reserve(node.parameters.size());
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
clone->parameters.push_back(CloneExpression(node.parameters[i]));
|
||||
for (const auto& parameter : node.parameters)
|
||||
clone->parameters.push_back(CloneExpression(parameter));
|
||||
|
||||
std::size_t targetFuncIndex;
|
||||
if (std::holds_alternative<std::string>(node.targetFunction))
|
||||
|
|
@ -495,6 +498,12 @@ namespace Nz::ShaderAst
|
|||
if (node.componentCount > 4)
|
||||
throw AstError{ "Cannot swizzle more than four elements" };
|
||||
|
||||
for (UInt32 swizzleIndex : node.components)
|
||||
{
|
||||
if (swizzleIndex >= 4)
|
||||
throw AstError{ "invalid swizzle" };
|
||||
}
|
||||
|
||||
MandatoryExpr(node.expression);
|
||||
|
||||
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
|
||||
|
|
@ -1315,12 +1324,10 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
|
||||
ExpressionType exprType = GetExpressionType(*node.expr);
|
||||
for (std::size_t i = 0; i < node.indices.size(); ++i)
|
||||
for (const auto& indexExpr : node.indices)
|
||||
{
|
||||
if (IsStructType(exprType))
|
||||
{
|
||||
auto& indexExpr = node.indices[i];
|
||||
|
||||
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
|
||||
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
|
||||
throw AstError{ "struct can only be accessed with constant i32 indices" };
|
||||
|
|
|
|||
|
|
@ -1195,27 +1195,9 @@ namespace Nz
|
|||
Visit(node.expression, true);
|
||||
Append(".");
|
||||
|
||||
const char* componentStr = "xyzw";
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
switch (node.components[i])
|
||||
{
|
||||
case ShaderAst::SwizzleComponent::First:
|
||||
Append("x");
|
||||
break;
|
||||
|
||||
case ShaderAst::SwizzleComponent::Second:
|
||||
Append("y");
|
||||
break;
|
||||
|
||||
case ShaderAst::SwizzleComponent::Third:
|
||||
Append("z");
|
||||
break;
|
||||
|
||||
case ShaderAst::SwizzleComponent::Fourth:
|
||||
Append("w");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Append(componentStr[node.components[i]]);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::VariableExpression& node)
|
||||
|
|
|
|||
|
|
@ -881,27 +881,9 @@ namespace Nz
|
|||
Visit(node.expression, true);
|
||||
Append(".");
|
||||
|
||||
const char* componentStr = "xyzw";
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
switch (node.components[i])
|
||||
{
|
||||
case ShaderAst::SwizzleComponent::First:
|
||||
Append("x");
|
||||
break;
|
||||
|
||||
case ShaderAst::SwizzleComponent::Second:
|
||||
Append("y");
|
||||
break;
|
||||
|
||||
case ShaderAst::SwizzleComponent::Third:
|
||||
Append("z");
|
||||
break;
|
||||
|
||||
case ShaderAst::SwizzleComponent::Fourth:
|
||||
Append("w");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Append(componentStr[node.components[i]]);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::VariableExpression& node)
|
||||
|
|
|
|||
|
|
@ -906,8 +906,9 @@ namespace Nz
|
|||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression);
|
||||
|
||||
UInt32 exprResultId = EvaluateExpression(node.expression);
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
|
||||
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
|
||||
|
||||
|
|
@ -917,7 +918,10 @@ namespace Nz
|
|||
|
||||
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(targetExprType);
|
||||
|
||||
// Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
if (IsVectorType(swizzledExpressionType))
|
||||
{
|
||||
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
|
||||
{
|
||||
appender(m_writer.GetTypeId(targetType));
|
||||
|
|
@ -926,23 +930,50 @@ namespace Nz
|
|||
appender(exprResultId);
|
||||
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
appender(UInt32(node.components[i]));
|
||||
appender(node.components[i]);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
assert(IsPrimitiveType(swizzledExpressionType));
|
||||
|
||||
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
|
||||
{
|
||||
appender(m_writer.GetTypeId(targetType));
|
||||
appender(resultId);
|
||||
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
appender(exprResultId);
|
||||
});
|
||||
}
|
||||
|
||||
PushResultId(resultId);
|
||||
}
|
||||
else if (IsVectorType(swizzledExpressionType))
|
||||
{
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
||||
|
||||
// Extract a single component from the vector
|
||||
assert(node.componentCount == 1);
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) );
|
||||
}
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, node.components[0]);
|
||||
|
||||
PushResultId(resultId);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Swizzling a primitive to itself (a.x for example), don't do anything
|
||||
assert(IsPrimitiveType(swizzledExpressionType));
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
assert(node.componentCount == 1);
|
||||
assert(node.components[0] == 0);
|
||||
|
||||
PushResultId(exprResultId);
|
||||
}
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Shader/SpirvBlock.hpp>
|
||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||
#include <numeric>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
|
|
@ -27,9 +28,57 @@ namespace Nz
|
|||
{
|
||||
m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId);
|
||||
},
|
||||
[&](const LocalVar& value)
|
||||
[&](const SwizzledPointer& swizzledPointer)
|
||||
{
|
||||
throw std::runtime_error("not yet implemented");
|
||||
if (swizzledPointer.componentCount > 1)
|
||||
{
|
||||
std::size_t vectorSize = swizzledPointer.swizzledType.componentCount;
|
||||
|
||||
UInt32 exprTypeId = m_writer.GetTypeId(swizzledPointer.swizzledType);
|
||||
|
||||
// Load original value (which will then be shuffled with new value)
|
||||
UInt32 originalVecId = m_visitor.AllocateResultId();
|
||||
m_block.Append(SpirvOp::OpLoad, exprTypeId, originalVecId, swizzledPointer.pointerId);
|
||||
|
||||
// Build a new composite type using OpVectorShuffle and store it
|
||||
StackArray<UInt32> indices = NazaraStackArrayNoInit(UInt32, vectorSize);
|
||||
std::iota(indices.begin(), indices.end(), UInt32(0u)); //< init with regular swizzle (0,1,2,3)
|
||||
|
||||
// override with swizzle components
|
||||
for (std::size_t i = 0; i < swizzledPointer.componentCount; ++i)
|
||||
indices[swizzledPointer.swizzleIndices[i]] = SafeCast<UInt32>(vectorSize + i);
|
||||
|
||||
UInt32 shuffleResultId = m_visitor.AllocateResultId();
|
||||
m_block.AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
|
||||
{
|
||||
appender(exprTypeId);
|
||||
appender(shuffleResultId);
|
||||
|
||||
appender(originalVecId);
|
||||
appender(resultId);
|
||||
|
||||
for (UInt32 index : indices)
|
||||
appender(index);
|
||||
});
|
||||
|
||||
// Store result
|
||||
m_block.Append(SpirvOp::OpStore, swizzledPointer.pointerId, shuffleResultId);
|
||||
}
|
||||
else
|
||||
{
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
|
||||
|
||||
assert(swizzledPointer.componentCount == 1);
|
||||
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME
|
||||
|
||||
// Access chain
|
||||
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
|
||||
|
||||
UInt32 pointerId = m_visitor.AllocateResultId();
|
||||
m_block.Append(SpirvOp::OpAccessChain, pointerType, pointerId, swizzledPointer.pointerId, indexId);
|
||||
m_block.Append(SpirvOp::OpStore, pointerId, resultId);
|
||||
}
|
||||
},
|
||||
[](std::monostate)
|
||||
{
|
||||
|
|
@ -67,10 +116,6 @@ namespace Nz
|
|||
|
||||
m_value = Pointer { pointer.storage, resultId };
|
||||
},
|
||||
[&](const LocalVar& value)
|
||||
{
|
||||
throw std::runtime_error("not yet implemented");
|
||||
},
|
||||
[](std::monostate)
|
||||
{
|
||||
throw std::runtime_error("an internal error occurred");
|
||||
|
|
@ -80,30 +125,34 @@ namespace Nz
|
|||
|
||||
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
if (node.componentCount != 1)
|
||||
throw std::runtime_error("swizzle with more than one component is not yet supported");
|
||||
|
||||
node.expression->Visit(*this);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
||||
|
||||
std::visit(overloaded
|
||||
{
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
UInt32 resultId = m_visitor.AllocateResultId();
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
|
||||
const auto& expressionType = GetExpressionType(*node.expression);
|
||||
assert(IsVectorType(expressionType));
|
||||
|
||||
Int32 indexCount = UnderlyingCast(node.components[0]) - UnderlyingCast(ShaderAst::SwizzleComponent::First);
|
||||
UInt32 indexId = m_writer.GetConstantId(indexCount);
|
||||
SwizzledPointer swizzledPointer{ pointer };
|
||||
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType);
|
||||
swizzledPointer.componentCount = node.componentCount;
|
||||
swizzledPointer.swizzleIndices = node.components;
|
||||
|
||||
m_block.Append(SpirvOp::OpAccessChain, pointerType, resultId, pointer.pointerId, indexId);
|
||||
|
||||
m_value = Pointer { pointer.storage, resultId };
|
||||
m_value = swizzledPointer;
|
||||
},
|
||||
[&](const LocalVar& value)
|
||||
[&](SwizzledPointer& swizzledPointer)
|
||||
{
|
||||
throw std::runtime_error("not yet implemented");
|
||||
// Swizzle the swizzle, keep common components
|
||||
std::array<UInt32, 4> newIndices;
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
assert(node.components[i] < node.componentCount);
|
||||
newIndices[i] = swizzledPointer.swizzleIndices[node.components[i]];
|
||||
}
|
||||
|
||||
swizzledPointer.componentCount = node.componentCount;
|
||||
swizzledPointer.swizzleIndices = newIndices;
|
||||
},
|
||||
[](std::monostate)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -359,7 +359,7 @@ namespace Nz
|
|||
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
Int32 indexCount = UnderlyingCast(node.components[i]) - UnderlyingCast(ShaderAst::SwizzleComponent::First);
|
||||
Int32 indexCount = SafeCast<Int32>(node.components[i]);
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(indexCount));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue