Big SpirVWriter refactor

This commit is contained in:
Jérôme Leclercq
2020-08-23 18:32:28 +02:00
parent 66a14721cb
commit 93de44d293
22 changed files with 1604 additions and 618 deletions

View File

@@ -10,6 +10,7 @@
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvStatementVisitor.hpp>
#include <tsl/ordered_map.h>
#include <tsl/ordered_set.h>
#include <SpirV/spirv.h>
@@ -154,14 +155,6 @@ namespace Nz
}
}
struct SpirvWriter::ExtVar
{
UInt32 pointerTypeId;
UInt32 typeId;
UInt32 varId;
std::optional<UInt32> valueId;
};
struct SpirvWriter::State
{
State() :
@@ -387,7 +380,8 @@ namespace Nz
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
}
Visit(functionStatements[funcIndex]);
SpirvStatementVisitor visitor(*this);
visitor.Visit(functionStatements[funcIndex]);
if (func.returnType == ShaderNodes::BasicType::Void)
state.instructions.Append(SpirvOp::OpReturn);
@@ -480,12 +474,6 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
}
UInt32 SpirvWriter::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
{
Visit(expr);
return PopResultId();
}
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
@@ -507,6 +495,43 @@ namespace Nz
});
}
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
{
auto it = m_currentState->builtinIds.find(builtin);
assert(it != m_currentState->builtinIds.end());
return it->second;
}
auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return it->second;
}
auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->outputIds.find(name);
assert(it != m_currentState->outputIds.end());
return it->second;
}
auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar&
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return it.value();
}
SpirvSection& SpirvWriter::GetInstructions()
{
return m_currentState->instructions;
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
@@ -517,20 +542,53 @@ namespace Nz
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type));
}
void SpirvWriter::PushResultId(UInt32 value)
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
{
m_currentState->resultIds.push_back(value);
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
return ReadVariable(it.value());
}
UInt32 SpirvWriter::PopResultId()
std::optional<UInt32> SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache)
{
if (m_currentState->resultIds.empty())
throw std::runtime_error("invalid operation");
auto it = m_currentState->inputIds.find(name);
assert(it != m_currentState->inputIds.end());
UInt32 resultId = m_currentState->resultIds.back();
m_currentState->resultIds.pop_back();
return ReadVariable(it.value(), OnlyCache{});
}
return resultId;
UInt32 SpirvWriter::ReadLocalVariable(const std::string& name)
{
auto it = m_currentState->varToResult.find(name);
assert(it != m_currentState->varToResult.end());
return it->second;
}
std::optional<UInt32> SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->varToResult.find(name);
if (it == m_currentState->varToResult.end())
return {};
return it->second;
}
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->uniformIds.find(name);
assert(it != m_currentState->uniformIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadVariable(ExtVar& var)
@@ -546,6 +604,14 @@ namespace Nz
return var.valueId.value();
}
std::optional<UInt32> SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache)
{
if (!var.valueId.has_value())
return {};
return var.valueId.value();
}
UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
@@ -578,572 +644,10 @@ namespace Nz
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type));
}
void SpirvWriter::Visit(ShaderNodes::AccessMember& node)
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
{
UInt32 pointerId;
SpirvStorageClass storage;
switch (node.structExpr->GetType())
{
case ShaderNodes::NodeType::Identifier:
{
auto& identifier = static_cast<ShaderNodes::Identifier&>(*node.structExpr);
switch (identifier.var->GetType())
{
case ShaderNodes::VariableType::BuiltinVariable:
{
auto& builtinvar = static_cast<ShaderNodes::BuiltinVariable&>(*identifier.var);
auto it = m_currentState->builtinIds.find(builtinvar.entry);
assert(it != m_currentState->builtinIds.end());
pointerId = it->second.varId;
break;
}
case ShaderNodes::VariableType::InputVariable:
{
auto& inputVar = static_cast<ShaderNodes::InputVariable&>(*identifier.var);
auto it = m_currentState->inputIds.find(inputVar.name);
assert(it != m_currentState->inputIds.end());
storage = SpirvStorageClass::Input;
pointerId = it->second.varId;
break;
}
case ShaderNodes::VariableType::OutputVariable:
{
auto& outputVar = static_cast<ShaderNodes::OutputVariable&>(*identifier.var);
auto it = m_currentState->outputIds.find(outputVar.name);
assert(it != m_currentState->outputIds.end());
storage = SpirvStorageClass::Output;
pointerId = it->second.varId;
break;
}
case ShaderNodes::VariableType::UniformVariable:
{
auto& uniformVar = static_cast<ShaderNodes::UniformVariable&>(*identifier.var);
auto it = m_currentState->uniformIds.find(uniformVar.name);
assert(it != m_currentState->uniformIds.end());
storage = SpirvStorageClass::Uniform;
pointerId = it->second.varId;
break;
}
case ShaderNodes::VariableType::LocalVariable:
case ShaderNodes::VariableType::ParameterVariable:
default:
throw std::runtime_error("not yet implemented");
}
break;
}
case ShaderNodes::NodeType::SwizzleOp: //< TODO
default:
throw std::runtime_error("not yet implemented");
}
UInt32 memberPointerId = AllocateResultId();
UInt32 pointerType = RegisterPointerType(node.exprType, storage); //< FIXME
UInt32 typeId = GetTypeId(node.exprType);
m_currentState->instructions.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(memberPointerId);
appender(pointerId);
for (std::size_t index : node.memberIndices)
appender(GetConstantId(Int32(index)));
});
UInt32 resultId = AllocateResultId();
m_currentState->instructions.Append(SpirvOp::OpLoad, typeId, resultId, memberPointerId);
PushResultId(resultId);
}
void SpirvWriter::Visit(ShaderNodes::AssignOp& node)
{
UInt32 result = EvaluateExpression(node.right);
switch (node.left->GetType())
{
case ShaderNodes::NodeType::Identifier:
{
auto& identifier = static_cast<ShaderNodes::Identifier&>(*node.left);
switch (identifier.var->GetType())
{
case ShaderNodes::VariableType::BuiltinVariable:
{
auto& builtinvar = static_cast<ShaderNodes::BuiltinVariable&>(*identifier.var);
auto it = m_currentState->builtinIds.find(builtinvar.entry);
assert(it != m_currentState->builtinIds.end());
m_currentState->instructions.Append(SpirvOp::OpStore, it->second.varId, result);
PushResultId(result);
break;
}
case ShaderNodes::VariableType::OutputVariable:
{
auto& outputVar = static_cast<ShaderNodes::OutputVariable&>(*identifier.var);
auto it = m_currentState->outputIds.find(outputVar.name);
assert(it != m_currentState->outputIds.end());
m_currentState->instructions.Append(SpirvOp::OpStore, it->second.varId, result);
PushResultId(result);
break;
}
case ShaderNodes::VariableType::InputVariable:
case ShaderNodes::VariableType::LocalVariable:
case ShaderNodes::VariableType::ParameterVariable:
case ShaderNodes::VariableType::UniformVariable:
default:
throw std::runtime_error("not yet implemented");
}
break;
}
case ShaderNodes::NodeType::SwizzleOp: //< TODO
default:
throw std::runtime_error("not yet implemented");
}
}
void SpirvWriter::Visit(ShaderNodes::Branch& node)
{
throw std::runtime_error("not yet implemented");
}
void SpirvWriter::Visit(ShaderNodes::BinaryOp& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
UInt32 rightOperand = EvaluateExpression(node.right);
UInt32 resultId = AllocateResultId();
bool swapOperands = false;
SpirvOp op = [&]
{
switch (node.op)
{
case ShaderNodes::BinaryType::Add:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIAdd;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Substract:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpISub;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
return SpirvOp::OpSDiv;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpUDiv;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Equality:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Multiply:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
default:
break;
}
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
break;
}
break;
}
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIMul;
case ShaderNodes::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
}
break;
}
default:
break;
}
break;
}
}
assert(false);
throw std::runtime_error("unexpected binary operation");
}();
if (swapOperands)
std::swap(leftOperand, rightOperand);
m_currentState->instructions.Append(op, GetTypeId(resultType), resultId, leftOperand, rightOperand);
PushResultId(resultId);
}
void SpirvWriter::Visit(ShaderNodes::Cast& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (const auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
exprResults.push_back(EvaluateExpression(exprPtr));
}
UInt32 resultId = AllocateResultId();
m_currentState->instructions.AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
{
appender(GetTypeId(targetType));
appender(resultId);
for (UInt32 exprResultId : exprResults)
appender(exprResultId);
});
PushResultId(resultId);
}
void SpirvWriter::Visit(ShaderNodes::Constant& node)
{
std::visit([&] (const auto& value)
{
PushResultId(GetConstantId(value));
}, node.value);
}
void SpirvWriter::Visit(ShaderNodes::DeclareVariable& node)
{
if (node.expression)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
m_currentState->varToResult[localVar.name] = EvaluateExpression(node.expression);
}
}
void SpirvWriter::Visit(ShaderNodes::ExpressionStatement& node)
{
Visit(node.expression);
PopResultId();
}
void SpirvWriter::Visit(ShaderNodes::Identifier& node)
{
Visit(node.var);
}
void SpirvWriter::Visit(ShaderNodes::IntrinsicCall& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
UInt32 typeId = GetTypeId(node.GetComponentType(vecType));
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
UInt32 resultId = AllocateResultId();
m_currentState->instructions.Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2);
PushResultId(resultId);
break;
}
case ShaderNodes::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
}
}
void SpirvWriter::Visit(ShaderNodes::Sample2D& node)
{
UInt32 typeId = GetTypeId(ShaderNodes::BasicType::Float4);
UInt32 samplerId = EvaluateExpression(node.sampler);
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
UInt32 resultId = AllocateResultId();
m_currentState->instructions.Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
}
void SpirvWriter::Visit(ShaderNodes::StatementBlock& node)
{
for (auto& statement : node.statements)
Visit(statement);
}
void SpirvWriter::Visit(ShaderNodes::SwizzleOp& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = AllocateResultId();
if (node.componentCount > 1)
{
// Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands
m_currentState->instructions.AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(GetTypeId(targetType));
appender(resultId);
appender(exprResultId);
appender(exprResultId);
for (std::size_t i = 0; i < node.componentCount; ++i)
appender(UInt32(node.components[0]) - UInt32(node.components[i]));
});
}
else
{
// Extract a single component from the vector
assert(node.componentCount == 1);
m_currentState->instructions.Append(SpirvOp::OpCompositeExtract, GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
}
PushResultId(resultId);
}
void SpirvWriter::Visit(ShaderNodes::BuiltinVariable& var)
{
throw std::runtime_error("not implemented yet");
}
void SpirvWriter::Visit(ShaderNodes::InputVariable& var)
{
auto it = m_currentState->inputIds.find(var.name);
assert(it != m_currentState->inputIds.end());
PushResultId(ReadVariable(it.value()));
}
void SpirvWriter::Visit(ShaderNodes::LocalVariable& var)
{
auto it = m_currentState->varToResult.find(var.name);
assert(it != m_currentState->varToResult.end());
PushResultId(it->second);
}
void SpirvWriter::Visit(ShaderNodes::OutputVariable& var)
{
auto it = m_currentState->outputIds.find(var.name);
assert(it != m_currentState->outputIds.end());
PushResultId(ReadVariable(it.value()));
}
void SpirvWriter::Visit(ShaderNodes::ParameterVariable& var)
{
throw std::runtime_error("not implemented yet");
}
void SpirvWriter::Visit(ShaderNodes::UniformVariable& var)
{
auto it = m_currentState->uniformIds.find(var.name);
assert(it != m_currentState->uniformIds.end());
PushResultId(ReadVariable(it.value()));
assert(m_currentState);
m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
}
void SpirvWriter::MergeBlocks(std::vector<UInt32>& output, const SpirvSection& from)