Refactor SpirV classes

SpirvStatementVisitor was merged with SpirvExpressionLoad
SpirvExpressionLoadAccessMember was renamed SpirvExpressionLoad
This commit is contained in:
Jérôme Leclercq 2020-08-23 21:56:30 +02:00
parent 6c379eff68
commit 77b66620c9
12 changed files with 591 additions and 714 deletions

View File

@ -0,0 +1,58 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SPIRVEXPRESSIONLOAD_HPP
#define NAZARA_SPIRVEXPRESSIONLOAD_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <vector>
namespace Nz
{
class SpirvWriter;
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAstVisitorExcept
{
public:
inline SpirvAstVisitor(SpirvWriter& writer);
SpirvAstVisitor(const SpirvAstVisitor&) = delete;
SpirvAstVisitor(SpirvAstVisitor&&) = delete;
~SpirvAstVisitor() = default;
UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr);
using ShaderAstVisitorExcept::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
SpirvAstVisitor& operator=(const SpirvAstVisitor&) = delete;
SpirvAstVisitor& operator=(SpirvAstVisitor&&) = delete;
private:
void PushResultId(UInt32 value);
UInt32 PopResultId();
std::vector<UInt32> m_resultIds;
SpirvWriter& m_writer;
};
}
#include <Nazara/Shader/SpirvAstVisitor.inl>
#endif

View File

@ -2,12 +2,12 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvStatementVisitor.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
inline SpirvStatementVisitor::SpirvStatementVisitor(SpirvWriter& writer) :
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer) :
m_writer(writer)
{
}

View File

@ -4,13 +4,14 @@
#pragma once
#ifndef NAZARA_SPIRVEXPRESSIONLOAD_HPP
#define NAZARA_SPIRVEXPRESSIONLOAD_HPP
#ifndef NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP
#define NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <vector>
namespace Nz
@ -25,37 +26,35 @@ namespace Nz
SpirvExpressionLoad(SpirvExpressionLoad&&) = delete;
~SpirvExpressionLoad() = default;
UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr);
UInt32 Evaluate(ShaderNodes::Expression& node);
using ShaderAstVisitorExcept::Visit;
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::BinaryOp& node) override;
void Visit(ShaderNodes::Cast& node) override;
void Visit(ShaderNodes::Constant& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::Identifier& node) override;
void Visit(ShaderNodes::IntrinsicCall& node) override;
void Visit(ShaderNodes::Sample2D& node) override;
void Visit(ShaderNodes::SwizzleOp& node) override;
using ShaderVarVisitorExcept::Visit;
void Visit(ShaderNodes::BuiltinVariable& var) override;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::LocalVariable& var) override;
void Visit(ShaderNodes::ParameterVariable& var) override;
void Visit(ShaderNodes::UniformVariable& var) override;
SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete;
SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete;
private:
void PushResultId(UInt32 value);
UInt32 PopResultId();
struct Pointer
{
SpirvStorageClass storage;
UInt32 resultId;
UInt32 pointedTypeId;
};
struct Value
{
UInt32 resultId;
};
std::vector<UInt32> m_resultIds;
SpirvWriter& m_writer;
std::variant<std::monostate, Pointer, Value> m_value;
};
}

View File

@ -1,62 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP
#define NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <vector>
namespace Nz
{
class SpirvWriter;
class NAZARA_SHADER_API SpirvExpressionLoadAccessMember : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept
{
public:
inline SpirvExpressionLoadAccessMember(SpirvWriter& writer);
SpirvExpressionLoadAccessMember(const SpirvExpressionLoadAccessMember&) = delete;
SpirvExpressionLoadAccessMember(SpirvExpressionLoadAccessMember&&) = delete;
~SpirvExpressionLoadAccessMember() = default;
UInt32 EvaluateExpression(ShaderNodes::AccessMember& expr);
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override;
void Visit(ShaderNodes::Identifier& node) override;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::InputVariable& var) override;
void Visit(ShaderNodes::UniformVariable& var) override;
SpirvExpressionLoadAccessMember& operator=(const SpirvExpressionLoadAccessMember&) = delete;
SpirvExpressionLoadAccessMember& operator=(SpirvExpressionLoadAccessMember&&) = delete;
private:
struct Pointer
{
SpirvStorageClass storage;
UInt32 resultId;
UInt32 pointedTypeId;
};
struct Value
{
UInt32 resultId;
};
SpirvWriter& m_writer;
std::variant<std::monostate, Pointer, Value> m_value;
};
}
#include <Nazara/Shader/SpirvExpressionLoadAccessMember.inl>
#endif

View File

@ -1,16 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionLoadAccessMember.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
inline SpirvExpressionLoadAccessMember::SpirvExpressionLoadAccessMember(SpirvWriter& writer) :
m_writer(writer)
{
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -1,43 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SPIRVSTATEMENTVISITOR_HPP
#define NAZARA_SPIRVSTATEMENTVISITOR_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
namespace Nz
{
class SpirvWriter;
class NAZARA_SHADER_API SpirvStatementVisitor : public ShaderAstVisitorExcept
{
public:
inline SpirvStatementVisitor(SpirvWriter& writer);
SpirvStatementVisitor(const SpirvStatementVisitor&) = delete;
SpirvStatementVisitor(SpirvStatementVisitor&&) = delete;
~SpirvStatementVisitor() = default;
using ShaderAstVisitor::Visit;
void Visit(ShaderNodes::AssignOp& node) override;
void Visit(ShaderNodes::Branch& node) override;
void Visit(ShaderNodes::DeclareVariable& node) override;
void Visit(ShaderNodes::ExpressionStatement& node) override;
void Visit(ShaderNodes::StatementBlock& node) override;
SpirvStatementVisitor& operator=(const SpirvStatementVisitor&) = delete;
SpirvStatementVisitor& operator=(SpirvStatementVisitor&&) = delete;
private:
SpirvWriter& m_writer;
};
}
#include <Nazara/Shader/SpirvStatementVisitor.inl>
#endif

View File

@ -25,10 +25,10 @@ namespace Nz
class NAZARA_SHADER_API SpirvWriter
{
friend class SpirvAstVisitor;
friend class SpirvExpressionLoad;
friend class SpirvExpressionLoadAccessMember;
friend class SpirvExpressionStore;
friend class SpirvStatementVisitor;
friend class SpirvVisitor;
public:
struct Environment;

View File

@ -0,0 +1,433 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
UInt32 SpirvAstVisitor::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
{
Visit(expr);
return PopResultId();
}
void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node)
{
SpirvExpressionLoad accessMemberVisitor(m_writer);
PushResultId(accessMemberVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderNodes::AssignOp& node)
{
UInt32 resultId = EvaluateExpression(node.right);
SpirvExpressionStore storeVisitor(m_writer);
storeVisitor.Store(node.left, resultId);
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
UInt32 rightOperand = EvaluateExpression(node.right);
UInt32 resultId = m_writer.AllocateResultId();
bool swapOperands = false;
SpirvOp op = [&]
{
switch (node.op)
{
case ShaderNodes::BinaryType::Add:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIAdd;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Substract:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpISub;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
return SpirvOp::OpSDiv;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpUDiv;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Equality:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Multiply:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
default:
break;
}
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
break;
}
break;
}
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIMul;
case ShaderNodes::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
}
break;
}
default:
break;
}
break;
}
}
assert(false);
throw std::runtime_error("unexpected binary operation");
}();
if (swapOperands)
std::swap(leftOperand, rightOperand);
m_writer.GetInstructions().Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand);
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (const auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
exprResults.push_back(EvaluateExpression(exprPtr));
}
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
for (UInt32 exprResultId : exprResults)
appender(exprResultId);
});
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::Constant& node)
{
std::visit([&] (const auto& value)
{
PushResultId(m_writer.GetConstantId(value));
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderNodes::DeclareVariable& node)
{
if (node.expression)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression));
}
}
void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node)
{
Visit(node.expression);
PopResultId();
}
void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node)
{
SpirvExpressionLoad loadVisitor(m_writer);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderNodes::IntrinsicCall& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType));
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2);
PushResultId(resultId);
break;
}
case ShaderNodes::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
}
}
void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node)
{
UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4);
UInt32 samplerId = EvaluateExpression(node.sampler);
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::StatementBlock& node)
{
for (auto& statement : node.statements)
Visit(statement);
}
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();
if (node.componentCount > 1)
{
// Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
appender(exprResultId);
appender(exprResultId);
for (std::size_t i = 0; i < node.componentCount; ++i)
appender(UInt32(node.components[0]) - UInt32(node.components[i]));
});
}
else
{
// Extract a single component from the vector
assert(node.componentCount == 1);
m_writer.GetInstructions().Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
}
PushResultId(resultId);
}
void SpirvAstVisitor::PushResultId(UInt32 value)
{
m_resultIds.push_back(value);
}
UInt32 SpirvAstVisitor::PopResultId()
{
if (m_resultIds.empty())
throw std::runtime_error("invalid operation");
UInt32 resultId = m_resultIds.back();
m_resultIds.pop_back();
return resultId;
}
}

View File

@ -5,317 +5,88 @@
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvExpressionLoadAccessMember.hpp>
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
namespace
{
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
}
UInt32 SpirvExpressionLoad::Evaluate(ShaderNodes::Expression& node)
{
node.Visit(*this);
return std::visit(overloaded
{
[&](const Pointer& pointer) -> UInt32
{
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId);
return resultId;
},
[&](const Value& value) -> UInt32
{
return value.resultId;
},
[this](std::monostate) -> UInt32
{
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
void SpirvExpressionLoad::Visit(ShaderNodes::AccessMember& node)
{
SpirvExpressionLoadAccessMember accessMemberVisitor(m_writer);
PushResultId(accessMemberVisitor.EvaluateExpression(node));
}
Visit(node.structExpr);
void SpirvExpressionLoad::Visit(ShaderNodes::AssignOp& node)
{
SpirvExpressionLoad loadVisitor(m_writer);
SpirvExpressionStore storeVisitor(m_writer);
storeVisitor.Store(node.left, EvaluateExpression(node.right));
}
void SpirvExpressionLoad::Visit(ShaderNodes::BinaryOp& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
UInt32 rightOperand = EvaluateExpression(node.right);
UInt32 resultId = m_writer.AllocateResultId();
bool swapOperands = false;
SpirvOp op = [&]
std::visit(overloaded
{
switch (node.op)
[&](const Pointer& pointer)
{
case ShaderNodes::BinaryType::Add:
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 typeId = m_writer.GetTypeId(node.exprType);
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFAdd;
appender(pointerType);
appender(resultId);
appender(pointer.resultId);
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIAdd;
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
});
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
m_value = Pointer { pointer.storage, resultId, typeId };
},
[&](const Value& value)
{
UInt32 resultId = m_writer.AllocateResultId();
UInt32 typeId = m_writer.GetTypeId(node.exprType);
case ShaderNodes::BinaryType::Substract:
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFSub;
appender(typeId);
appender(resultId);
appender(value.resultId);
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpISub;
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
});
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
return SpirvOp::OpSDiv;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpUDiv;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Equality:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
break;
}
}
case ShaderNodes::BinaryType::Multiply:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
default:
break;
}
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
break;
}
break;
}
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
return SpirvOp::OpIMul;
case ShaderNodes::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
}
break;
}
default:
break;
}
break;
}
m_value = Value { resultId };
},
[this](std::monostate)
{
throw std::runtime_error("an internal error occurred");
}
assert(false);
throw std::runtime_error("unexpected binary operation");
}();
if (swapOperands)
std::swap(leftOperand, rightOperand);
m_writer.GetInstructions().Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand);
PushResultId(resultId);
}
void SpirvExpressionLoad::Visit(ShaderNodes::Cast& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (const auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
exprResults.push_back(EvaluateExpression(exprPtr));
}
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
for (UInt32 exprResultId : exprResults)
appender(exprResultId);
});
PushResultId(resultId);
}
void SpirvExpressionLoad::Visit(ShaderNodes::Constant& node)
{
std::visit([&] (const auto& value)
{
PushResultId(m_writer.GetConstantId(value));
}, node.value);
}
void SpirvExpressionLoad::Visit(ShaderNodes::DeclareVariable& node)
{
if (node.expression)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression));
}
}
void SpirvExpressionLoad::Visit(ShaderNodes::ExpressionStatement& node)
{
Visit(node.expression);
PopResultId();
}, m_value);
}
void SpirvExpressionLoad::Visit(ShaderNodes::Identifier& node)
@ -323,126 +94,28 @@ namespace Nz
Visit(node.var);
}
void SpirvExpressionLoad::Visit(ShaderNodes::IntrinsicCall& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType));
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2);
PushResultId(resultId);
break;
}
case ShaderNodes::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
}
}
void SpirvExpressionLoad::Visit(ShaderNodes::Sample2D& node)
{
UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4);
UInt32 samplerId = EvaluateExpression(node.sampler);
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
}
void SpirvExpressionLoad::Visit(ShaderNodes::SwizzleOp& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();
if (node.componentCount > 1)
{
// Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
appender(exprResultId);
appender(exprResultId);
for (std::size_t i = 0; i < node.componentCount; ++i)
appender(UInt32(node.components[0]) - UInt32(node.components[i]));
});
}
else
{
// Extract a single component from the vector
assert(node.componentCount == 1);
m_writer.GetInstructions().Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
}
PushResultId(resultId);
}
void SpirvExpressionLoad::Visit(ShaderNodes::BuiltinVariable& /*var*/)
{
throw std::runtime_error("not implemented yet");
}
void SpirvExpressionLoad::Visit(ShaderNodes::InputVariable& var)
{
PushResultId(m_writer.ReadInputVariable(var.name));
auto inputVar = m_writer.GetInputVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
else
m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId };
}
void SpirvExpressionLoad::Visit(ShaderNodes::LocalVariable& var)
{
PushResultId(m_writer.ReadLocalVariable(var.name));
}
void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& /*var*/)
{
throw std::runtime_error("not implemented yet");
m_value = Value{ m_writer.ReadLocalVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var)
{
PushResultId(m_writer.ReadUniformVariable(var.name));
}
auto uniformVar = m_writer.GetUniformVariable(var.name);
UInt32 SpirvExpressionLoad::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
{
Visit(expr);
return PopResultId();
}
void SpirvExpressionLoad::PushResultId(UInt32 value)
{
m_resultIds.push_back(value);
}
UInt32 SpirvExpressionLoad::PopResultId()
{
if (m_resultIds.empty())
throw std::runtime_error("invalid operation");
UInt32 resultId = m_resultIds.back();
m_resultIds.pop_back();
return resultId;
if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
else
m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId };
}
}

View File

@ -1,116 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionLoadAccessMember.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
namespace
{
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
}
UInt32 SpirvExpressionLoadAccessMember::EvaluateExpression(ShaderNodes::AccessMember& expr)
{
Visit(expr);
return std::visit(overloaded
{
[&](const Pointer& pointer) -> UInt32
{
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId);
return resultId;
},
[&](const Value& value) -> UInt32
{
return value.resultId;
},
[this](std::monostate) -> UInt32
{
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::AccessMember& node)
{
Visit(node.structExpr);
std::visit(overloaded
{
[&](const Pointer& pointer)
{
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 typeId = m_writer.GetTypeId(node.exprType);
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.resultId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
});
m_value = Pointer { pointer.storage, resultId, typeId };
},
[&](const Value& value)
{
UInt32 resultId = m_writer.AllocateResultId();
UInt32 typeId = m_writer.GetTypeId(node.exprType);
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
appender(typeId);
appender(resultId);
appender(value.resultId);
for (std::size_t index : node.memberIndices)
appender(m_writer.GetConstantId(Int32(index)));
});
m_value = Value { resultId };
},
[this](std::monostate)
{
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::Identifier& node)
{
Visit(node.var);
}
void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::InputVariable& var)
{
auto inputVar = m_writer.GetInputVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
else
m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId };
}
void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::UniformVariable& var)
{
auto uniformVar = m_writer.GetUniformVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
else
m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId };
}
}

View File

@ -1,49 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvStatementVisitor.hpp>
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void SpirvStatementVisitor::Visit(ShaderNodes::AssignOp& node)
{
SpirvExpressionLoad loadVisitor(m_writer);
SpirvExpressionStore storeVisitor(m_writer);
storeVisitor.Store(node.left, loadVisitor.EvaluateExpression(node.right));
}
void SpirvStatementVisitor::Visit(ShaderNodes::Branch& node)
{
throw std::runtime_error("not yet implemented");
}
void SpirvStatementVisitor::Visit(ShaderNodes::DeclareVariable& node)
{
if (node.expression)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
SpirvExpressionLoad loadVisitor(m_writer);
m_writer.WriteLocalVariable(localVar.name, loadVisitor.EvaluateExpression(node.expression));
}
}
void SpirvStatementVisitor::Visit(ShaderNodes::ExpressionStatement& node)
{
SpirvExpressionLoad loadVisitor(m_writer);
loadVisitor.Visit(node.expression);
}
void SpirvStatementVisitor::Visit(ShaderNodes::StatementBlock& node)
{
for (auto& statement : node.statements)
Visit(statement);
}
}

View File

@ -7,10 +7,10 @@
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvStatementVisitor.hpp>
#include <tsl/ordered_map.h>
#include <tsl/ordered_set.h>
#include <SpirV/spirv.h>
@ -380,7 +380,7 @@ namespace Nz
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
}
SpirvStatementVisitor visitor(*this);
SpirvAstVisitor visitor(*this);
visitor.Visit(functionStatements[funcIndex]);
if (func.returnType == ShaderNodes::BasicType::Void)