Shader: First working version on both Vulkan & OpenGL (ES)
This commit is contained in:
225
src/Nazara/Shader/Ast/TransformVisitor.cpp
Normal file
225
src/Nazara/Shader/Ast/TransformVisitor.cpp
Normal file
@@ -0,0 +1,225 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
StatementPtr TransformVisitor::Transform(StatementPtr& nodePtr)
|
||||
{
|
||||
StatementPtr clone;
|
||||
|
||||
PushScope(); //< Global scope
|
||||
{
|
||||
clone = AstCloner::Clone(nodePtr);
|
||||
}
|
||||
PopScope();
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(BranchStatement& node)
|
||||
{
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
cond.condition->Visit(*this);
|
||||
cond.statement->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
if (node.elseStatement)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
node.elseStatement->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(ConditionalStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
ExpressionType TransformVisitor::ResolveType(const ExpressionType& exprType)
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> ExpressionType
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, NoType> ||
|
||||
std::is_same_v<T, PrimitiveType> ||
|
||||
std::is_same_v<T, MatrixType> ||
|
||||
std::is_same_v<T, SamplerType> ||
|
||||
std::is_same_v<T, StructType> ||
|
||||
std::is_same_v<T, VectorType>)
|
||||
{
|
||||
return exprType;
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, IdentifierType>)
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(arg.name);
|
||||
assert(identifier);
|
||||
assert(std::holds_alternative<Struct>(identifier->value));
|
||||
|
||||
return StructType{ std::get<Struct>(identifier->value).structIndex };
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, UniformType>)
|
||||
{
|
||||
return std::visit([&](auto&& containedArg)
|
||||
{
|
||||
ExpressionType resolvedType = ResolveType(containedArg);
|
||||
assert(std::holds_alternative<StructType>(resolvedType));
|
||||
|
||||
return UniformType{ std::get<StructType>(resolvedType) };
|
||||
}, arg.containedType);
|
||||
}
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, exprType);
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(DeclareExternalStatement& node)
|
||||
{
|
||||
for (auto& extVar : node.externalVars)
|
||||
{
|
||||
extVar.type = ResolveType(extVar.type);
|
||||
|
||||
std::size_t varIndex = RegisterVariable(extVar.name);
|
||||
if (!node.varIndex)
|
||||
node.varIndex = varIndex;
|
||||
}
|
||||
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
node.funcIndex = m_nextFuncIndex++;
|
||||
node.returnType = ResolveType(node.returnType);
|
||||
for (auto& parameter : node.parameters)
|
||||
parameter.type = ResolveType(parameter.type);
|
||||
|
||||
PushScope();
|
||||
{
|
||||
for (auto& parameter : node.parameters)
|
||||
{
|
||||
std::size_t varIndex = RegisterVariable(parameter.name);
|
||||
if (!node.varIndex)
|
||||
node.varIndex = varIndex;
|
||||
}
|
||||
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
node.structIndex = RegisterStruct(node.description.name, node.description);
|
||||
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
node.varType = ResolveType(node.varType);
|
||||
node.varIndex = RegisterVariable(node.varName);
|
||||
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
|
||||
void TransformVisitor::Visit(MultiStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
{
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
PopScope();
|
||||
}
|
||||
|
||||
ExpressionPtr TransformVisitor::Clone(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
auto accessMemberIndex = std::make_unique<AccessMemberIndexExpression>();
|
||||
accessMemberIndex->structExpr = CloneExpression(node.structExpr);
|
||||
accessMemberIndex->cachedExpressionType = node.cachedExpressionType;
|
||||
accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size());
|
||||
|
||||
ExpressionType exprType = GetExpressionType(*node.structExpr);
|
||||
for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i)
|
||||
{
|
||||
exprType = ResolveType(exprType);
|
||||
assert(std::holds_alternative<StructType>(exprType));
|
||||
|
||||
std::size_t structIndex = std::get<StructType>(exprType).structIndex;
|
||||
assert(structIndex < m_structs.size());
|
||||
const StructDescription& structDesc = m_structs[structIndex];
|
||||
|
||||
auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; });
|
||||
assert(it != structDesc.members.end());
|
||||
|
||||
accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it);
|
||||
exprType = it->type;
|
||||
}
|
||||
|
||||
return accessMemberIndex;
|
||||
}
|
||||
|
||||
ExpressionPtr TransformVisitor::Clone(CastExpression& node)
|
||||
{
|
||||
ExpressionPtr expr = AstCloner::Clone(node);
|
||||
|
||||
CastExpression* castExpr = static_cast<CastExpression*>(expr.get());
|
||||
castExpr->targetType = ResolveType(castExpr->targetType);
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
ExpressionPtr TransformVisitor::Clone(IdentifierExpression& node)
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(node.identifier);
|
||||
assert(identifier);
|
||||
assert(std::holds_alternative<Variable>(identifier->value));
|
||||
|
||||
auto varExpr = std::make_unique<VariableExpression>();
|
||||
varExpr->cachedExpressionType = node.cachedExpressionType;
|
||||
varExpr->variableId = std::get<Variable>(identifier->value).varIndex;
|
||||
|
||||
return varExpr;
|
||||
}
|
||||
|
||||
ExpressionPtr TransformVisitor::CloneExpression(ExpressionPtr& expr)
|
||||
{
|
||||
ExpressionPtr exprPtr = AstCloner::CloneExpression(expr);
|
||||
if (exprPtr)
|
||||
{
|
||||
assert(exprPtr->cachedExpressionType);
|
||||
*exprPtr->cachedExpressionType = ResolveType(*exprPtr->cachedExpressionType);
|
||||
}
|
||||
|
||||
return exprPtr;
|
||||
}
|
||||
|
||||
void TransformVisitor::PushScope()
|
||||
{
|
||||
m_scopeSizes.push_back(m_identifiersInScope.size());
|
||||
}
|
||||
|
||||
void TransformVisitor::PopScope()
|
||||
{
|
||||
assert(!m_scopeSizes.empty());
|
||||
m_identifiersInScope.resize(m_scopeSizes.back());
|
||||
m_scopeSizes.pop_back();
|
||||
}
|
||||
}
|
||||
@@ -330,6 +330,11 @@ namespace Nz
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::StructType& structType)
|
||||
{
|
||||
throw std::runtime_error("unexpected struct type");
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
|
||||
{
|
||||
/* TODO */
|
||||
@@ -371,6 +376,7 @@ namespace Nz
|
||||
|
||||
m_currentState->stream << param;
|
||||
}
|
||||
|
||||
template<typename T1, typename T2, typename... Args>
|
||||
void GlslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params)
|
||||
{
|
||||
@@ -595,7 +601,7 @@ namespace Nz
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
void GlslWriter::Visit(ShaderAst::AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
Visit(node.structExpr, true);
|
||||
|
||||
@@ -741,8 +747,6 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
|
||||
|
||||
for (const auto& externalVar : node.externalVars)
|
||||
{
|
||||
std::optional<long long> bindingIndex;
|
||||
@@ -774,7 +778,7 @@ namespace Nz
|
||||
|
||||
EnterScope();
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
|
||||
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(std::get<ShaderAst::UniformType>(externalVar.type).containedType).name);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
|
||||
@@ -42,13 +42,25 @@ namespace Nz::ShaderAst
|
||||
return PopStatement();
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareExternalStatement>();
|
||||
clone->attributes = node.attributes;
|
||||
clone->externalVars = node.externalVars;
|
||||
clone->varIndex = node.varIndex;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->attributes = node.attributes;
|
||||
clone->funcIndex = node.funcIndex;
|
||||
clone->name = node.name;
|
||||
clone->parameters = node.parameters;
|
||||
clone->returnType = node.returnType;
|
||||
clone->varIndex = node.varIndex;
|
||||
|
||||
clone->statements.reserve(node.statements.size());
|
||||
for (auto& statement : node.statements)
|
||||
@@ -57,15 +69,95 @@ namespace Nz::ShaderAst
|
||||
return clone;
|
||||
}
|
||||
|
||||
void AstCloner::Visit(AccessMemberExpression& node)
|
||||
StatementPtr AstCloner::Clone(DeclareStructStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<AccessMemberExpression>();
|
||||
auto clone = std::make_unique<DeclareStructStatement>();
|
||||
clone->structIndex = node.structIndex;
|
||||
clone->description = node.description;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(DeclareVariableStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareVariableStatement>();
|
||||
clone->varIndex = node.varIndex;
|
||||
clone->varName = node.varName;
|
||||
clone->varType = node.varType;
|
||||
clone->initialExpression = CloneExpression(node.initialExpression);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<AccessMemberIdentifierExpression>();
|
||||
clone->memberIdentifiers = node.memberIdentifiers;
|
||||
clone->structExpr = CloneExpression(node.structExpr);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(AccessMemberIndexExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<AccessMemberIndexExpression>();
|
||||
clone->memberIndices = node.memberIndices;
|
||||
clone->structExpr = CloneExpression(node.structExpr);
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(CastExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<CastExpression>();
|
||||
clone->targetType = node.targetType;
|
||||
|
||||
std::size_t expressionCount = 0;
|
||||
for (auto& expr : node.expressions)
|
||||
{
|
||||
if (!expr)
|
||||
break;
|
||||
|
||||
clone->expressions[expressionCount++] = CloneExpression(expr);
|
||||
}
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(IdentifierExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<IdentifierExpression>();
|
||||
clone->identifier = node.identifier;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(VariableExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<VariableExpression>();
|
||||
clone->variableId = node.variableId;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
void AstCloner::Visit(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
return PushExpression(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(AccessMemberIndexExpression& node)
|
||||
{
|
||||
return PushExpression(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(AssignExpression& node)
|
||||
@@ -94,21 +186,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstCloner::Visit(CastExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<CastExpression>();
|
||||
clone->targetType = node.targetType;
|
||||
|
||||
std::size_t expressionCount = 0;
|
||||
for (auto& expr : node.expressions)
|
||||
{
|
||||
if (!expr)
|
||||
break;
|
||||
|
||||
clone->expressions[expressionCount++] = CloneExpression(expr);
|
||||
}
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
PushExpression(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(ConditionalExpression& node)
|
||||
@@ -135,12 +213,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstCloner::Visit(IdentifierExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<IdentifierExpression>();
|
||||
clone->identifier = node.identifier;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
PushExpression(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(IntrinsicExpression& node)
|
||||
@@ -169,6 +242,11 @@ namespace Nz::ShaderAst
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(VariableExpression& node)
|
||||
{
|
||||
PushExpression(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(BranchStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<BranchStatement>();
|
||||
@@ -197,11 +275,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstCloner::Visit(DeclareExternalStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareExternalStatement>();
|
||||
clone->attributes = node.attributes;
|
||||
clone->externalVars = node.externalVars;
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
PushStatement(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(DeclareFunctionStatement& node)
|
||||
@@ -211,20 +285,12 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstCloner::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareStructStatement>();
|
||||
clone->description = node.description;
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
PushStatement(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareVariableStatement>();
|
||||
clone->varName = node.varName;
|
||||
clone->varType = node.varType;
|
||||
clone->initialExpression = CloneExpression(node.initialExpression);
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
PushStatement(Clone(node));
|
||||
}
|
||||
|
||||
void AstCloner::Visit(DiscardStatement& /*node*/)
|
||||
|
||||
@@ -7,7 +7,12 @@
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
void AstRecursiveVisitor::Visit(AccessMemberExpression& node)
|
||||
void AstRecursiveVisitor::Visit(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
node.structExpr->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(AccessMemberIndexExpression& node)
|
||||
{
|
||||
node.structExpr->Visit(*this);
|
||||
}
|
||||
@@ -62,6 +67,11 @@ namespace Nz::ShaderAst
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(VariableExpression& node)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(BranchStatement& node)
|
||||
{
|
||||
for (auto& cond : node.condStatements)
|
||||
|
||||
@@ -53,7 +53,7 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
ExpressionType subType = extVar.type;
|
||||
if (IsUniformType(subType))
|
||||
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
|
||||
subType = std::get<IdentifierType>(std::get<UniformType>(subType).containedType);
|
||||
|
||||
RegisterVariable(extVar.name, std::move(subType));
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ namespace Nz::ShaderAst
|
||||
};
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(AccessMemberExpression& node)
|
||||
void AstSerializerBase::Serialize(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
Node(node.structExpr);
|
||||
|
||||
@@ -42,6 +42,15 @@ namespace Nz::ShaderAst
|
||||
Value(identifier);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(AccessMemberIndexExpression& node)
|
||||
{
|
||||
Node(node.structExpr);
|
||||
|
||||
Container(node.memberIndices);
|
||||
for (std::size_t& identifier : node.memberIndices)
|
||||
SizeT(identifier);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(AssignExpression& node)
|
||||
{
|
||||
Enum(node.op);
|
||||
@@ -133,6 +142,11 @@ namespace Nz::ShaderAst
|
||||
Enum(node.components[i]);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(VariableExpression& node)
|
||||
{
|
||||
SizeT(node.variableId);
|
||||
}
|
||||
|
||||
|
||||
void AstSerializerBase::Serialize(BranchStatement& node)
|
||||
{
|
||||
@@ -364,14 +378,19 @@ namespace Nz::ShaderAst
|
||||
m_stream << UInt32(arg.dim);
|
||||
m_stream << UInt32(arg.sampledType);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, UniformType>)
|
||||
else if constexpr (std::is_same_v<T, StructType>)
|
||||
{
|
||||
m_stream << UInt8(5);
|
||||
m_stream << arg.containedType.name;
|
||||
m_stream << UInt32(arg.structIndex);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, UniformType>)
|
||||
{
|
||||
m_stream << UInt8(6);
|
||||
m_stream << std::get<IdentifierType>(arg.containedType).name;
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, VectorType>)
|
||||
{
|
||||
m_stream << UInt8(6);
|
||||
m_stream << UInt8(7);
|
||||
m_stream << UInt32(arg.componentCount);
|
||||
m_stream << UInt32(arg.type);
|
||||
}
|
||||
@@ -621,7 +640,18 @@ namespace Nz::ShaderAst
|
||||
break;
|
||||
}
|
||||
|
||||
case 5: //< UniformType
|
||||
case 5: //< StructType
|
||||
{
|
||||
UInt32 structIndex;
|
||||
Value(structIndex);
|
||||
|
||||
type = StructType{
|
||||
structIndex
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
case 6: //< UniformType
|
||||
{
|
||||
std::string containedType;
|
||||
Value(containedType);
|
||||
@@ -634,7 +664,7 @@ namespace Nz::ShaderAst
|
||||
break;
|
||||
}
|
||||
|
||||
case 6: //< VectorType
|
||||
case 7: //< VectorType
|
||||
{
|
||||
UInt32 componentCount;
|
||||
PrimitiveType componentType;
|
||||
|
||||
@@ -13,7 +13,12 @@ namespace Nz::ShaderAst
|
||||
return m_expressionCategory;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(AccessMemberExpression& node)
|
||||
void ShaderAstValueCategory::Visit(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
node.structExpr->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(AccessMemberIndexExpression& node)
|
||||
{
|
||||
node.structExpr->Visit(*this);
|
||||
}
|
||||
@@ -66,4 +71,9 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(VariableExpression& node)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::LValue;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ namespace Nz::ShaderAst
|
||||
return expressionType;
|
||||
}
|
||||
|
||||
void AstValidator::Visit(AccessMemberExpression& node)
|
||||
void AstValidator::Visit(AccessMemberIdentifierExpression& node)
|
||||
{
|
||||
// Register expressions types
|
||||
AstScopedVisitor::Visit(node);
|
||||
@@ -351,7 +351,7 @@ namespace Nz::ShaderAst
|
||||
if (!exprPtr)
|
||||
break;
|
||||
|
||||
ExpressionType exprType = GetExpressionType(*exprPtr);
|
||||
const ExpressionType& exprType = GetExpressionType(*exprPtr);
|
||||
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
||||
throw AstError{ "incompatible type" };
|
||||
|
||||
@@ -552,14 +552,17 @@ namespace Nz::ShaderAst
|
||||
|
||||
void AstValidator::Visit(DeclareExternalStatement& node)
|
||||
{
|
||||
for (const auto& [attributeType, arg] : node.attributes)
|
||||
if (!node.attributes.empty())
|
||||
throw AstError{ "unhandled attribute for external block" };
|
||||
|
||||
/*for (const auto& [attributeType, arg] : node.attributes)
|
||||
{
|
||||
switch (attributeType)
|
||||
{
|
||||
default:
|
||||
throw AstError{ "unhandled attribute for external block" };
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
for (const auto& extVar : node.externalVars)
|
||||
{
|
||||
|
||||
@@ -602,7 +602,7 @@ namespace Nz::ShaderLang
|
||||
|
||||
if (Peek().type == TokenType::Dot)
|
||||
{
|
||||
std::unique_ptr<ShaderAst::AccessMemberExpression> accessMemberNode = std::make_unique<ShaderAst::AccessMemberExpression>();
|
||||
std::unique_ptr<ShaderAst::AccessMemberIdentifierExpression> accessMemberNode = std::make_unique<ShaderAst::AccessMemberIdentifierExpression>();
|
||||
accessMemberNode->structExpr = std::move(identifierExpr);
|
||||
|
||||
do
|
||||
@@ -685,9 +685,9 @@ namespace Nz::ShaderLang
|
||||
if (IsVariableInScope(identifier))
|
||||
{
|
||||
auto node = ParseIdentifier();
|
||||
if (node->GetType() == ShaderAst::NodeType::AccessMemberExpression)
|
||||
if (node->GetType() == ShaderAst::NodeType::AccessMemberIdentifierExpression)
|
||||
{
|
||||
ShaderAst::AccessMemberExpression* memberExpr = static_cast<ShaderAst::AccessMemberExpression*>(node.get());
|
||||
ShaderAst::AccessMemberIdentifierExpression* memberExpr = static_cast<ShaderAst::AccessMemberIdentifierExpression*>(node.get());
|
||||
if (!memberExpr->memberIdentifiers.empty() && memberExpr->memberIdentifiers.front() == "Sample")
|
||||
{
|
||||
if (Peek().type == TokenType::OpenParenthesis)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Core/CallOnExit.hpp>
|
||||
#include <Nazara/Core/StackVector.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
|
||||
@@ -12,6 +13,11 @@
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
UInt32 SpirvAstVisitor::AllocateResultId()
|
||||
{
|
||||
return m_writer.AllocateResultId();
|
||||
}
|
||||
|
||||
UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr)
|
||||
{
|
||||
expr->Visit(*this);
|
||||
@@ -20,9 +26,16 @@ namespace Nz
|
||||
return PopResultId();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
auto SpirvAstVisitor::GetVariable(std::size_t varIndex) const -> const Variable&
|
||||
{
|
||||
SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock);
|
||||
assert(varIndex < m_variables.size());
|
||||
assert(m_variables[varIndex]);
|
||||
return *m_variables[varIndex];
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberIndexExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad accessMemberVisitor(m_writer, *this, *m_currentBlock);
|
||||
PushResultId(accessMemberVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
@@ -30,7 +43,7 @@ namespace Nz
|
||||
{
|
||||
UInt32 resultId = EvaluateExpression(node.right);
|
||||
|
||||
SpirvExpressionStore storeVisitor(m_writer, *m_currentBlock);
|
||||
SpirvExpressionStore storeVisitor(m_writer, *this, *m_currentBlock);
|
||||
storeVisitor.Store(node.left, resultId);
|
||||
|
||||
PushResultId(resultId);
|
||||
@@ -38,18 +51,24 @@ namespace Nz
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
|
||||
{
|
||||
ShaderAst::ExpressionType resultExprType = GetExpressionType(node);
|
||||
assert(IsPrimitiveType(resultExprType));
|
||||
auto RetrieveBaseType = [](const ShaderAst::ExpressionType& exprType)
|
||||
{
|
||||
if (IsPrimitiveType(exprType))
|
||||
return std::get<ShaderAst::PrimitiveType>(exprType);
|
||||
else if (IsVectorType(exprType))
|
||||
return std::get<ShaderAst::VectorType>(exprType).type;
|
||||
else if (IsMatrixType(exprType))
|
||||
return std::get<ShaderAst::MatrixType>(exprType).type;
|
||||
else
|
||||
throw std::runtime_error("unexpected type");
|
||||
};
|
||||
|
||||
ShaderAst::ExpressionType leftExprType = GetExpressionType(*node.left);
|
||||
assert(IsPrimitiveType(leftExprType));
|
||||
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left);
|
||||
const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right);
|
||||
|
||||
ShaderAst::ExpressionType rightExprType = GetExpressionType(*node.right);
|
||||
assert(IsPrimitiveType(rightExprType));
|
||||
|
||||
ShaderAst::PrimitiveType resultType = std::get<ShaderAst::PrimitiveType>(resultExprType);
|
||||
ShaderAst::PrimitiveType leftType = std::get<ShaderAst::PrimitiveType>(leftExprType);
|
||||
ShaderAst::PrimitiveType rightType = std::get<ShaderAst::PrimitiveType>(rightExprType);
|
||||
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
|
||||
ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
|
||||
|
||||
|
||||
UInt32 leftOperand = EvaluateExpression(node.left);
|
||||
@@ -64,28 +83,16 @@ namespace Nz
|
||||
{
|
||||
case ShaderAst::BinaryType::Add:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFAdd;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpIAdd;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -94,28 +101,16 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::Subtract:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFSub;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpISub;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -124,30 +119,18 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::Divide:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFDiv;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
return SpirvOp::OpSDiv;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpUDiv;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -156,31 +139,17 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::CompEq:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
return SpirvOp::OpLogicalEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFOrdEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpIEqual;
|
||||
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
// break;
|
||||
}
|
||||
|
||||
break;
|
||||
@@ -188,30 +157,18 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::CompGe:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFOrdGreaterThan;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
return SpirvOp::OpSGreaterThan;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpUGreaterThan;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -220,30 +177,18 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::CompGt:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFOrdGreaterThanEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
return SpirvOp::OpSGreaterThanEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpUGreaterThanEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -252,30 +197,18 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::CompLe:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFOrdLessThanEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
return SpirvOp::OpSLessThanEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpULessThanEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -284,30 +217,18 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::CompLt:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFOrdLessThan;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
return SpirvOp::OpSLessThan;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpULessThan;
|
||||
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -316,31 +237,17 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::CompNe:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
return SpirvOp::OpLogicalNotEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
return SpirvOp::OpFOrdNotEqual;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpINotEqual;
|
||||
|
||||
// case ShaderAst::PrimitiveType::Sampler2D:
|
||||
// case ShaderAst::PrimitiveType::Void:
|
||||
// break;
|
||||
}
|
||||
|
||||
break;
|
||||
@@ -348,81 +255,51 @@ namespace Nz
|
||||
|
||||
case ShaderAst::BinaryType::Multiply:
|
||||
{
|
||||
switch (leftType)
|
||||
switch (leftTypeBase)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
{
|
||||
switch (rightType)
|
||||
if (IsPrimitiveType(leftType))
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
return SpirvOp::OpFMul;
|
||||
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// swapOperands = true;
|
||||
// return SpirvOp::OpVectorTimesScalar;
|
||||
//
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
// swapOperands = true;
|
||||
// return SpirvOp::OpMatrixTimesScalar;
|
||||
|
||||
default:
|
||||
break;
|
||||
// Handle float * matrix|vector as matrix|vector * float
|
||||
if (IsMatrixType(rightType))
|
||||
{
|
||||
swapOperands = true;
|
||||
return SpirvOp::OpMatrixTimesScalar;
|
||||
}
|
||||
else if (IsVectorType(rightType))
|
||||
{
|
||||
swapOperands = true;
|
||||
return SpirvOp::OpVectorTimesScalar;
|
||||
}
|
||||
}
|
||||
else if (IsPrimitiveType(rightType))
|
||||
{
|
||||
if (IsMatrixType(leftType))
|
||||
return SpirvOp::OpMatrixTimesScalar;
|
||||
else if (IsVectorType(leftType))
|
||||
return SpirvOp::OpVectorTimesScalar;
|
||||
}
|
||||
else if (IsMatrixType(leftType))
|
||||
{
|
||||
if (IsMatrixType(rightType))
|
||||
return SpirvOp::OpMatrixTimesMatrix;
|
||||
else if (IsVectorType(rightType))
|
||||
return SpirvOp::OpMatrixTimesVector;
|
||||
}
|
||||
else if (IsMatrixType(rightType))
|
||||
{
|
||||
assert(IsVectorType(leftType));
|
||||
return SpirvOp::OpVectorTimesMatrix;
|
||||
}
|
||||
|
||||
break;
|
||||
return SpirvOp::OpFMul;
|
||||
}
|
||||
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// {
|
||||
// switch (rightType)
|
||||
// {
|
||||
// case ShaderAst::PrimitiveType::Float32:
|
||||
// return SpirvOp::OpVectorTimesScalar;
|
||||
//
|
||||
// case ShaderAst::PrimitiveType::Float2:
|
||||
// case ShaderAst::PrimitiveType::Float3:
|
||||
// case ShaderAst::PrimitiveType::Float4:
|
||||
// return SpirvOp::OpFMul;
|
||||
//
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
// return SpirvOp::OpVectorTimesMatrix;
|
||||
//
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// break;
|
||||
// }
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
// case ShaderAst::PrimitiveType::Int2:
|
||||
// case ShaderAst::PrimitiveType::Int3:
|
||||
// case ShaderAst::PrimitiveType::Int4:
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
// case ShaderAst::PrimitiveType::UInt2:
|
||||
// case ShaderAst::PrimitiveType::UInt3:
|
||||
// case ShaderAst::PrimitiveType::UInt4:
|
||||
return SpirvOp::OpIMul;
|
||||
|
||||
// case ShaderAst::PrimitiveType::Mat4x4:
|
||||
// {
|
||||
// switch (rightType)
|
||||
// {
|
||||
// case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpMatrixTimesScalar;
|
||||
// case ShaderAst::PrimitiveType::Float4: return SpirvOp::OpMatrixTimesVector;
|
||||
// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
|
||||
//
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// break;
|
||||
// }
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -454,7 +331,7 @@ namespace Nz
|
||||
firstCond.statement->Visit(*this);
|
||||
|
||||
SpirvBlock mergeBlock(m_writer);
|
||||
m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
|
||||
m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
|
||||
|
||||
std::optional<std::size_t> nextBlock;
|
||||
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
|
||||
@@ -463,10 +340,10 @@ namespace Nz
|
||||
|
||||
SpirvBlock contentBlock(m_writer);
|
||||
|
||||
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
|
||||
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
|
||||
|
||||
previousConditionId = EvaluateExpression(statement.condition);
|
||||
m_blocks.emplace_back(std::move(previousContentBlock));
|
||||
m_functionBlocks.emplace_back(std::move(previousContentBlock));
|
||||
previousContentBlock = std::move(contentBlock);
|
||||
|
||||
m_currentBlock = &previousContentBlock;
|
||||
@@ -479,54 +356,148 @@ namespace Nz
|
||||
SpirvBlock elseBlock(m_writer);
|
||||
|
||||
m_currentBlock = &elseBlock;
|
||||
|
||||
node.elseStatement->Visit(*this);
|
||||
|
||||
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice
|
||||
|
||||
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
|
||||
m_blocks.emplace_back(std::move(previousContentBlock));
|
||||
m_blocks.emplace_back(std::move(elseBlock));
|
||||
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
|
||||
m_functionBlocks.emplace_back(std::move(previousContentBlock));
|
||||
m_functionBlocks.emplace_back(std::move(elseBlock));
|
||||
}
|
||||
else
|
||||
{
|
||||
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
|
||||
m_blocks.emplace_back(std::move(previousContentBlock));
|
||||
m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
|
||||
m_functionBlocks.emplace_back(std::move(previousContentBlock));
|
||||
}
|
||||
|
||||
m_blocks.emplace_back(std::move(mergeBlock));
|
||||
m_functionBlocks.emplace_back(std::move(mergeBlock));
|
||||
|
||||
m_currentBlock = &m_blocks.back();
|
||||
m_currentBlock = &m_functionBlocks.back();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& targetExprType = node.targetType;
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
||||
|
||||
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
|
||||
|
||||
for (auto& exprPtr : node.expressions)
|
||||
if (IsPrimitiveType(targetExprType))
|
||||
{
|
||||
if (!exprPtr)
|
||||
break;
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
||||
|
||||
exprResults.push_back(EvaluateExpression(exprPtr));
|
||||
assert(node.expressions[0] && !node.expressions[1]);
|
||||
ShaderAst::ExpressionPtr& expression = node.expressions[0];
|
||||
|
||||
assert(expression->cachedExpressionType.has_value());
|
||||
const ShaderAst::ExpressionType& exprType = expression->cachedExpressionType.value();
|
||||
assert(IsPrimitiveType(exprType));
|
||||
ShaderAst::PrimitiveType fromType = std::get<ShaderAst::PrimitiveType>(exprType);
|
||||
|
||||
UInt32 fromId = EvaluateExpression(expression);
|
||||
if (targetType == fromType)
|
||||
return PushResultId(fromId);
|
||||
|
||||
std::optional<SpirvOp> castOp;
|
||||
switch (targetType)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
throw std::runtime_error("unsupported cast to boolean");
|
||||
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
{
|
||||
switch (fromType)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
throw std::runtime_error("unsupported cast from boolean");
|
||||
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
break; //< Already handled
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
castOp = SpirvOp::OpConvertSToF;
|
||||
break;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
castOp = SpirvOp::OpConvertUToF;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
{
|
||||
switch (fromType)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
throw std::runtime_error("unsupported cast from boolean");
|
||||
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
castOp = SpirvOp::OpConvertFToS;
|
||||
break;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
break; //< Already handled
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
castOp = SpirvOp::OpSConvert;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
{
|
||||
switch (fromType)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean:
|
||||
throw std::runtime_error("unsupported cast from boolean");
|
||||
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
castOp = SpirvOp::OpConvertFToU;
|
||||
break;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
castOp = SpirvOp::OpUConvert;
|
||||
break;
|
||||
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
break; //< Already handled
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert(castOp);
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(*castOp, m_writer.GetTypeId(targetType), resultId, fromId);
|
||||
|
||||
throw std::runtime_error("toudou");
|
||||
}
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
|
||||
else
|
||||
{
|
||||
appender(m_writer.GetTypeId(targetType));
|
||||
appender(resultId);
|
||||
assert(IsVectorType(targetExprType));
|
||||
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
|
||||
|
||||
for (UInt32 exprResultId : exprResults)
|
||||
appender(exprResultId);
|
||||
});
|
||||
for (auto& exprPtr : node.expressions)
|
||||
{
|
||||
if (!exprPtr)
|
||||
break;
|
||||
|
||||
PushResultId(resultId);
|
||||
exprResults.push_back(EvaluateExpression(exprPtr));
|
||||
}
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
|
||||
{
|
||||
appender(m_writer.GetTypeId(targetExprType));
|
||||
appender(resultId);
|
||||
|
||||
for (UInt32 exprResultId : exprResults)
|
||||
appender(exprResultId);
|
||||
});
|
||||
|
||||
PushResultId(resultId);
|
||||
}
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node)
|
||||
@@ -551,10 +522,108 @@ namespace Nz
|
||||
}, node.value);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
|
||||
std::size_t varIndex = *node.varIndex;
|
||||
for (auto&& extVar : node.externalVars)
|
||||
RegisterExternalVariable(varIndex++, extVar.type);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareFunctionStatement& node)
|
||||
{
|
||||
assert(node.funcIndex);
|
||||
m_funcIndex = *node.funcIndex;
|
||||
|
||||
auto& func = m_funcData[m_funcIndex];
|
||||
func.funcId = m_writer.AllocateResultId();
|
||||
|
||||
m_instructions.Append(SpirvOp::OpFunction, func.returnTypeId, func.funcId, 0, func.funcTypeId);
|
||||
|
||||
if (!func.parameters.empty())
|
||||
{
|
||||
std::size_t varIndex = *node.varIndex;
|
||||
for (const auto& param : func.parameters)
|
||||
{
|
||||
UInt32 paramResultId = m_writer.AllocateResultId();
|
||||
m_instructions.Append(SpirvOp::OpFunctionParameter, param.typeId, paramResultId);
|
||||
|
||||
RegisterVariable(varIndex++, param.typeId, paramResultId, SpirvStorageClass::Function);
|
||||
}
|
||||
}
|
||||
|
||||
m_functionBlocks.clear();
|
||||
|
||||
m_currentBlock = &m_functionBlocks.emplace_back(m_writer);
|
||||
CallOnExit resetCurrentBlock([&] { m_currentBlock = nullptr; });
|
||||
|
||||
for (auto& var : func.variables)
|
||||
{
|
||||
var.varId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(SpirvOp::OpVariable, var.typeId, var.varId, SpirvStorageClass::Function);
|
||||
}
|
||||
|
||||
if (func.entryPointData)
|
||||
{
|
||||
auto& entryPointData = *func.entryPointData;
|
||||
if (entryPointData.inputStruct)
|
||||
{
|
||||
auto& inputStruct = *entryPointData.inputStruct;
|
||||
|
||||
std::size_t varIndex = *node.varIndex;
|
||||
|
||||
UInt32 paramId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(SpirvOp::OpVariable, inputStruct.pointerId, paramId, SpirvStorageClass::Function);
|
||||
|
||||
for (const auto& input : entryPointData.inputs)
|
||||
{
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(SpirvOp::OpAccessChain, input.memberPointerId, resultId, paramId, input.memberIndexConstantId);
|
||||
m_currentBlock->Append(SpirvOp::OpCopyMemory, resultId, input.varId);
|
||||
}
|
||||
|
||||
RegisterVariable(varIndex, inputStruct.typeId, paramId, SpirvStorageClass::Function);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& statementPtr : node.statements)
|
||||
statementPtr->Visit(*this);
|
||||
|
||||
// Add implicit return
|
||||
if (!m_functionBlocks.back().IsTerminated())
|
||||
m_functionBlocks.back().Append(SpirvOp::OpReturn);
|
||||
|
||||
for (SpirvBlock& block : m_functionBlocks)
|
||||
m_instructions.AppendSection(block);
|
||||
|
||||
m_instructions.Append(SpirvOp::OpFunctionEnd);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareStructStatement& node)
|
||||
{
|
||||
assert(node.structIndex);
|
||||
RegisterStruct(*node.structIndex, node.description);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
const auto& func = m_funcData[m_funcIndex];
|
||||
|
||||
UInt32 pointerTypeId = m_writer.GetPointerTypeId(node.varType, SpirvStorageClass::Function);
|
||||
UInt32 typeId = m_writer.GetTypeId(node.varType);
|
||||
|
||||
assert(node.varIndex);
|
||||
auto varIt = func.varIndexToVarId.find(*node.varIndex);
|
||||
UInt32 varId = func.variables[varIt->second].varId;
|
||||
|
||||
RegisterVariable(*node.varIndex, typeId, varId, SpirvStorageClass::Function);
|
||||
|
||||
if (node.initialExpression)
|
||||
m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression));
|
||||
{
|
||||
UInt32 value = EvaluateExpression(node.initialExpression);
|
||||
m_currentBlock->Append(SpirvOp::OpStore, varId, value);
|
||||
}
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/)
|
||||
@@ -569,19 +638,13 @@ namespace Nz
|
||||
PopResultId();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock);
|
||||
PushResultId(loadVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
{
|
||||
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(vecExprType));
|
||||
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
||||
@@ -598,6 +661,19 @@ namespace Nz
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderAst::IntrinsicType::SampleTexture:
|
||||
{
|
||||
UInt32 typeId = m_writer.GetTypeId(ShaderAst::VectorType{4, ShaderAst::PrimitiveType::Float32});
|
||||
|
||||
UInt32 samplerId = EvaluateExpression(node.parameters[0]);
|
||||
UInt32 coordinatesId = EvaluateExpression(node.parameters[1]);
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
|
||||
PushResultId(resultId);
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderAst::IntrinsicType::CrossProduct:
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
@@ -609,23 +685,44 @@ namespace Nz
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
|
||||
{
|
||||
if (node.returnExpr)
|
||||
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
|
||||
else
|
||||
m_currentBlock->Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node)
|
||||
{
|
||||
for (auto& statement : node.statements)
|
||||
statement->Visit(*this);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
|
||||
{
|
||||
if (node.returnExpr)
|
||||
{
|
||||
// Handle entry point return
|
||||
const auto& func = m_funcData[m_funcIndex];
|
||||
if (func.entryPointData)
|
||||
{
|
||||
auto& entryPointData = *func.entryPointData;
|
||||
if (entryPointData.outputStructTypeId)
|
||||
{
|
||||
UInt32 paramId = EvaluateExpression(node.returnExpr);
|
||||
for (const auto& output : entryPointData.outputs)
|
||||
{
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(SpirvOp::OpCompositeExtract, output.typeId, resultId, paramId, output.memberIndex);
|
||||
m_currentBlock->Append(SpirvOp::OpStore, output.varId, resultId);
|
||||
}
|
||||
}
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpReturn);
|
||||
}
|
||||
else
|
||||
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
|
||||
}
|
||||
else
|
||||
m_currentBlock->Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
ShaderAst::ExpressionType targetExprType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
||||
@@ -658,6 +755,12 @@ namespace Nz
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::VariableExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad loadVisitor(m_writer, *this, *m_currentBlock);
|
||||
PushResultId(loadVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::PushResultId(UInt32 value)
|
||||
{
|
||||
m_resultIds.push_back(value);
|
||||
|
||||
@@ -50,11 +50,6 @@ namespace Nz
|
||||
return Compare(lhs.parameters, rhs.parameters) && Compare(lhs.returnType, rhs.returnType);
|
||||
}
|
||||
|
||||
bool Compare(const Identifier& lhs, const Identifier& rhs) const
|
||||
{
|
||||
return lhs.name == rhs.name;
|
||||
}
|
||||
|
||||
bool Compare(const Image& lhs, const Image& rhs) const
|
||||
{
|
||||
return lhs.arrayed == rhs.arrayed
|
||||
@@ -114,6 +109,9 @@ namespace Nz
|
||||
if (lhs.debugName != rhs.debugName)
|
||||
return false;
|
||||
|
||||
if (lhs.funcId != rhs.funcId)
|
||||
return false;
|
||||
|
||||
if (!Compare(lhs.initializer, rhs.initializer))
|
||||
return false;
|
||||
|
||||
@@ -231,11 +229,6 @@ namespace Nz
|
||||
void Register(const Integer&) {}
|
||||
void Register(const Void&) {}
|
||||
|
||||
void Register(const Identifier& identifier)
|
||||
{
|
||||
Register(identifier);
|
||||
}
|
||||
|
||||
void Register(const Image& image)
|
||||
{
|
||||
Register(image.sampledType);
|
||||
@@ -406,6 +399,7 @@ namespace Nz
|
||||
tsl::ordered_map<std::variant<AnyConstant, AnyType>, UInt32 /*id*/, AnyHasher, Eq> ids;
|
||||
tsl::ordered_map<Variable, UInt32 /*id*/, AnyHasher, Eq> variableIds;
|
||||
tsl::ordered_map<Structure, FieldOffsets /*fieldOffsets*/, AnyHasher, Eq> structureSizes;
|
||||
StructCallback structCallback;
|
||||
UInt32& nextResultId;
|
||||
};
|
||||
|
||||
@@ -417,132 +411,8 @@ namespace Nz
|
||||
SpirvConstantCache::SpirvConstantCache(SpirvConstantCache&& cache) noexcept = default;
|
||||
|
||||
SpirvConstantCache::~SpirvConstantCache() = default;
|
||||
|
||||
UInt32 SpirvConstantCache::GetId(const Constant& c)
|
||||
{
|
||||
auto it = m_internal->ids.find(c.constant);
|
||||
if (it == m_internal->ids.end())
|
||||
throw std::runtime_error("constant is not registered");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::GetId(const Type& t)
|
||||
{
|
||||
auto it = m_internal->ids.find(t.type);
|
||||
if (it == m_internal->ids.end())
|
||||
throw std::runtime_error("constant is not registered");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::GetId(const Variable& v)
|
||||
{
|
||||
auto it = m_internal->variableIds.find(v);
|
||||
if (it == m_internal->variableIds.end())
|
||||
throw std::runtime_error("variable is not registered");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::Register(Constant c)
|
||||
{
|
||||
AnyConstant& constant = c.constant;
|
||||
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(constant);
|
||||
|
||||
std::size_t h = m_internal->ids.hash_function()(constant);
|
||||
auto it = m_internal->ids.find(constant, h);
|
||||
if (it == m_internal->ids.end())
|
||||
{
|
||||
UInt32 resultId = m_internal->nextResultId++;
|
||||
it = m_internal->ids.emplace(std::move(constant), resultId).first;
|
||||
}
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::Register(Type t)
|
||||
{
|
||||
AnyType& type = t.type;
|
||||
if (std::holds_alternative<Identifier>(type))
|
||||
{
|
||||
assert(m_identifierCallback);
|
||||
return Register(*m_identifierCallback(std::get<Identifier>(type).name));
|
||||
}
|
||||
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(type);
|
||||
|
||||
std::size_t h = m_internal->ids.hash_function()(type);
|
||||
auto it = m_internal->ids.find(type, h);
|
||||
if (it == m_internal->ids.end())
|
||||
{
|
||||
UInt32 resultId = m_internal->nextResultId++;
|
||||
it = m_internal->ids.emplace(std::move(type), resultId).first;
|
||||
}
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::Register(Variable v)
|
||||
{
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(v);
|
||||
|
||||
std::size_t h = m_internal->variableIds.hash_function()(v);
|
||||
auto it = m_internal->variableIds.find(v, h);
|
||||
if (it == m_internal->variableIds.end())
|
||||
{
|
||||
UInt32 resultId = m_internal->nextResultId++;
|
||||
it = m_internal->variableIds.emplace(std::move(v), resultId).first;
|
||||
}
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
void SpirvConstantCache::SetIdentifierCallback(IdentifierCallback callback)
|
||||
{
|
||||
m_identifierCallback = std::move(callback);
|
||||
}
|
||||
|
||||
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
|
||||
{
|
||||
for (auto&& [object, id] : m_internal->ids)
|
||||
{
|
||||
UInt32 resultId = id;
|
||||
|
||||
std::visit(overloaded
|
||||
{
|
||||
[&](const AnyConstant& constant) { Write(constant, resultId, constants); },
|
||||
[&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); },
|
||||
}, object);
|
||||
}
|
||||
|
||||
for (auto&& [variable, id] : m_internal->variableIds)
|
||||
{
|
||||
const auto& var = variable;
|
||||
UInt32 resultId = id;
|
||||
|
||||
if (!variable.debugName.empty())
|
||||
debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName);
|
||||
|
||||
constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender)
|
||||
{
|
||||
appender(GetId(*var.type));
|
||||
appender(resultId);
|
||||
appender(var.storageClass);
|
||||
|
||||
if (var.initializer)
|
||||
appender(GetId((*var.initializer)->constant));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default;
|
||||
|
||||
auto SpirvConstantCache::BuildConstant(const ShaderConstantValue& value) -> ConstantPtr
|
||||
|
||||
auto SpirvConstantCache::BuildConstant(const ShaderAst::ConstantValue& value) const -> ConstantPtr
|
||||
{
|
||||
return std::make_shared<Constant>(std::visit([&](auto&& arg) -> SpirvConstantCache::AnyConstant
|
||||
{
|
||||
@@ -590,7 +460,7 @@ namespace Nz
|
||||
}, value));
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) -> TypePtr
|
||||
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) const -> TypePtr
|
||||
{
|
||||
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
|
||||
parameterTypes.reserve(parameters.size());
|
||||
@@ -604,7 +474,7 @@ namespace Nz
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(Pointer{
|
||||
BuildType(type),
|
||||
@@ -612,7 +482,7 @@ namespace Nz
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) -> TypePtr
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(Pointer{
|
||||
BuildType(type),
|
||||
@@ -620,7 +490,7 @@ namespace Nz
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) const -> TypePtr
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> TypePtr
|
||||
{
|
||||
@@ -628,16 +498,13 @@ namespace Nz
|
||||
}, type);
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(
|
||||
Identifier{
|
||||
type.name
|
||||
}
|
||||
);
|
||||
// No IdentifierType is expected (as they should have been resolved by now)
|
||||
throw std::runtime_error("unexpected identifier");
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>([&]() -> AnyType
|
||||
{
|
||||
@@ -657,7 +524,7 @@ namespace Nz
|
||||
}());
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(
|
||||
Matrix{
|
||||
@@ -668,12 +535,12 @@ namespace Nz
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(Void{});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) const -> TypePtr
|
||||
{
|
||||
//TODO
|
||||
auto imageType = Image{
|
||||
@@ -690,7 +557,13 @@ namespace Nz
|
||||
return std::make_shared<Type>(SampledImage{ std::make_shared<Type>(imageType) });
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::StructType& type) const -> TypePtr
|
||||
{
|
||||
assert(m_internal->structCallback);
|
||||
return BuildType(m_internal->structCallback(type.structIndex));
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::StructDescription& structDesc) const -> TypePtr
|
||||
{
|
||||
Structure sType;
|
||||
sType.name = structDesc.name;
|
||||
@@ -705,11 +578,136 @@ namespace Nz
|
||||
return std::make_shared<Type>(std::move(sType));
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) const -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(Vector{ BuildType(type.type), UInt32(type.componentCount) });
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr
|
||||
{
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(type.containedType));
|
||||
return BuildType(std::get<ShaderAst::StructType>(type.containedType));
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::GetId(const Constant& c)
|
||||
{
|
||||
auto it = m_internal->ids.find(c.constant);
|
||||
if (it == m_internal->ids.end())
|
||||
throw std::runtime_error("constant is not registered");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::GetId(const Type& t)
|
||||
{
|
||||
auto it = m_internal->ids.find(t.type);
|
||||
if (it == m_internal->ids.end())
|
||||
throw std::runtime_error("type is not registered");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::GetId(const Variable& v)
|
||||
{
|
||||
auto it = m_internal->variableIds.find(v);
|
||||
if (it == m_internal->variableIds.end())
|
||||
throw std::runtime_error("variable is not registered");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::Register(Constant c)
|
||||
{
|
||||
AnyConstant& constant = c.constant;
|
||||
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(constant);
|
||||
|
||||
std::size_t h = m_internal->ids.hash_function()(constant);
|
||||
auto it = m_internal->ids.find(constant, h);
|
||||
if (it == m_internal->ids.end())
|
||||
{
|
||||
UInt32 resultId = m_internal->nextResultId++;
|
||||
it = m_internal->ids.emplace(std::move(constant), resultId).first;
|
||||
}
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::Register(Type t)
|
||||
{
|
||||
AnyType& type = t.type;
|
||||
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(type);
|
||||
|
||||
std::size_t h = m_internal->ids.hash_function()(type);
|
||||
auto it = m_internal->ids.find(type, h);
|
||||
if (it == m_internal->ids.end())
|
||||
{
|
||||
UInt32 resultId = m_internal->nextResultId++;
|
||||
it = m_internal->ids.emplace(std::move(type), resultId).first;
|
||||
}
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvConstantCache::Register(Variable v)
|
||||
{
|
||||
DepRegisterer registerer(*this);
|
||||
registerer.Register(v);
|
||||
|
||||
std::size_t h = m_internal->variableIds.hash_function()(v);
|
||||
auto it = m_internal->variableIds.find(v, h);
|
||||
if (it == m_internal->variableIds.end())
|
||||
{
|
||||
UInt32 resultId = m_internal->nextResultId++;
|
||||
it = m_internal->variableIds.emplace(std::move(v), resultId).first;
|
||||
}
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
void SpirvConstantCache::SetStructCallback(StructCallback callback)
|
||||
{
|
||||
m_internal->structCallback = std::move(callback);
|
||||
}
|
||||
|
||||
void SpirvConstantCache::Write(SpirvSection& annotations, SpirvSection& constants, SpirvSection& debugInfos)
|
||||
{
|
||||
for (auto&& [object, id] : m_internal->ids)
|
||||
{
|
||||
UInt32 resultId = id;
|
||||
|
||||
std::visit(overloaded
|
||||
{
|
||||
[&](const AnyConstant& constant) { Write(constant, resultId, constants); },
|
||||
[&](const AnyType& type) { Write(type, resultId, annotations, constants, debugInfos); },
|
||||
}, object);
|
||||
}
|
||||
|
||||
for (auto&& [variable, id] : m_internal->variableIds)
|
||||
{
|
||||
const auto& var = variable;
|
||||
UInt32 resultId = id;
|
||||
|
||||
if (!variable.debugName.empty())
|
||||
debugInfos.Append(SpirvOp::OpName, resultId, variable.debugName);
|
||||
|
||||
constants.AppendVariadic(SpirvOp::OpVariable, [&](const auto& appender)
|
||||
{
|
||||
appender(GetId(*var.type));
|
||||
appender(resultId);
|
||||
appender(var.storageClass);
|
||||
|
||||
if (var.initializer)
|
||||
appender(GetId((*var.initializer)->constant));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
SpirvConstantCache& SpirvConstantCache::operator=(SpirvConstantCache&& cache) noexcept = default;
|
||||
|
||||
void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants)
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
|
||||
@@ -38,6 +38,8 @@ namespace Nz
|
||||
|
||||
while (m_currentCodepoint < m_codepointEnd)
|
||||
{
|
||||
const UInt32* instructionBegin = m_currentCodepoint;
|
||||
|
||||
UInt32 firstWord = ReadWord();
|
||||
|
||||
UInt16 wordCount = static_cast<UInt16>((firstWord >> 16) & 0xFFFF);
|
||||
@@ -50,7 +52,7 @@ namespace Nz
|
||||
if (!HandleOpcode(*inst, wordCount))
|
||||
break;
|
||||
|
||||
m_currentCodepoint += wordCount - 1;
|
||||
m_currentCodepoint = instructionBegin + wordCount;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
|
||||
#include <Nazara/Core/StackVector.hpp>
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Shader/SpirvBlock.hpp>
|
||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
@@ -24,9 +24,8 @@ namespace Nz
|
||||
{
|
||||
[this](const Pointer& pointer) -> UInt32
|
||||
{
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
|
||||
m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId);
|
||||
UInt32 resultId = m_visitor.AllocateResultId();
|
||||
m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.pointerId);
|
||||
|
||||
return resultId;
|
||||
},
|
||||
@@ -41,25 +40,26 @@ namespace Nz
|
||||
}, m_value);
|
||||
}
|
||||
|
||||
/*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberIndexExpression& node)
|
||||
{
|
||||
Visit(node.structExpr);
|
||||
node.structExpr->Visit(*this);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
||||
|
||||
UInt32 resultId = m_visitor.AllocateResultId();
|
||||
UInt32 typeId = m_writer.GetTypeId(exprType);
|
||||
|
||||
std::visit(overloaded
|
||||
{
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr);
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
|
||||
UInt32 typeId = m_writer.GetTypeId(node.exprType);
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
|
||||
|
||||
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
|
||||
{
|
||||
appender(pointerType);
|
||||
appender(resultId);
|
||||
appender(pointer.resultId);
|
||||
appender(pointer.pointerId);
|
||||
|
||||
for (std::size_t index : node.memberIndices)
|
||||
appender(m_writer.GetConstantId(Int32(index)));
|
||||
@@ -69,9 +69,6 @@ namespace Nz
|
||||
},
|
||||
[&](const Value& value)
|
||||
{
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
UInt32 typeId = m_writer.GetTypeId(node.exprType);
|
||||
|
||||
m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
|
||||
{
|
||||
appender(typeId);
|
||||
@@ -89,15 +86,11 @@ namespace Nz
|
||||
throw std::runtime_error("an internal error occurred");
|
||||
}
|
||||
}, m_value);
|
||||
}*/
|
||||
}
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node)
|
||||
void SpirvExpressionLoad::Visit(ShaderAst::VariableExpression& node)
|
||||
{
|
||||
if (node.identifier == "d")
|
||||
m_value = Value{ m_writer.ReadLocalVariable(node.identifier) };
|
||||
else
|
||||
m_value = Value{ m_writer.ReadParameterVariable(node.identifier) };
|
||||
|
||||
//Visit(node.var);
|
||||
const auto& var = m_visitor.GetVariable(node.variableId);
|
||||
m_value = Pointer{ var.storage, var.pointerId, var.pointedTypeId };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/SpirvExpressionStore.hpp>
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Shader/SpirvBlock.hpp>
|
||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
@@ -23,11 +24,11 @@ namespace Nz
|
||||
{
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
m_block.Append(SpirvOp::OpStore, pointer.resultId, resultId);
|
||||
m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId);
|
||||
},
|
||||
[&](const LocalVar& value)
|
||||
{
|
||||
m_writer.WriteLocalVariable(value.varName, resultId);
|
||||
throw std::runtime_error("not yet implemented");
|
||||
},
|
||||
[](std::monostate)
|
||||
{
|
||||
@@ -36,49 +37,50 @@ namespace Nz
|
||||
}, m_value);
|
||||
}
|
||||
|
||||
/*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
void SpirvExpressionStore::Visit(ShaderAst::AccessMemberIndexExpression& node)
|
||||
{
|
||||
Visit(node.structExpr);
|
||||
node.structExpr->Visit(*this);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
||||
|
||||
std::visit(overloaded
|
||||
{
|
||||
[&](const Pointer& pointer) -> UInt32
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
|
||||
UInt32 resultId = m_visitor.AllocateResultId();
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
|
||||
|
||||
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
|
||||
{
|
||||
appender(pointerType);
|
||||
appender(resultId);
|
||||
appender(pointer.resultId);
|
||||
appender(pointer.pointerId);
|
||||
|
||||
for (std::size_t index : node.memberIndices)
|
||||
appender(m_writer.GetConstantId(Int32(index)));
|
||||
});
|
||||
|
||||
m_value = Pointer{ pointer.storage, resultId };
|
||||
|
||||
return resultId;
|
||||
m_value = Pointer { pointer.storage, resultId };
|
||||
},
|
||||
[](const LocalVar& value) -> UInt32
|
||||
[&](const LocalVar& value)
|
||||
{
|
||||
throw std::runtime_error("not yet implemented");
|
||||
},
|
||||
[](std::monostate) -> UInt32
|
||||
[](std::monostate)
|
||||
{
|
||||
throw std::runtime_error("an internal error occurred");
|
||||
}
|
||||
}, m_value);
|
||||
}*/
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
m_value = LocalVar{ node.identifier };
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
throw std::runtime_error("not yet implemented");
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderAst::VariableExpression& node)
|
||||
{
|
||||
const auto& var = m_visitor.GetVariable(node.variableId);
|
||||
m_value = Pointer{ var.storage, var.pointerId };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ namespace Nz
|
||||
UInt32 resultId = 0;
|
||||
|
||||
std::size_t currentOperand = 0;
|
||||
const UInt32* endPtr = startPtr + wordCount;
|
||||
const UInt32* endPtr = startPtr + wordCount - 1;
|
||||
while (GetCurrentPtr() < endPtr)
|
||||
{
|
||||
const SpirvInstruction::Operand* operand = &instruction.operands[currentOperand];
|
||||
@@ -209,7 +209,7 @@ namespace Nz
|
||||
|
||||
m_currentState->stream << "\n";
|
||||
|
||||
assert(GetCurrentPtr() == startPtr + wordCount);
|
||||
assert(GetCurrentPtr() == startPtr + wordCount - 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
#include <Nazara/Shader/SpirvConstantCache.hpp>
|
||||
#include <Nazara/Shader/SpirvData.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
|
||||
#include <tsl/ordered_map.h>
|
||||
#include <tsl/ordered_set.h>
|
||||
#include <SpirV/GLSL.std.450.h>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@@ -25,34 +27,61 @@ namespace Nz
|
||||
{
|
||||
namespace
|
||||
{
|
||||
//FIXME: Have this only once
|
||||
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
|
||||
{ "frag", ShaderStageType::Fragment },
|
||||
{ "vert", ShaderStageType::Vertex },
|
||||
};
|
||||
|
||||
struct Builtin
|
||||
{
|
||||
const char* debugName;
|
||||
ShaderStageTypeFlags compatibleStages;
|
||||
SpirvBuiltIn decoration;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, Builtin> s_builtinMapping = {
|
||||
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
|
||||
};
|
||||
|
||||
class PreVisitor : public ShaderAst::AstScopedVisitor
|
||||
{
|
||||
public:
|
||||
struct UniformVar
|
||||
{
|
||||
std::optional<UInt32> bindingIndex;
|
||||
UInt32 pointerId;
|
||||
};
|
||||
|
||||
using BuiltinDecoration = std::map<UInt32, SpirvBuiltIn>;
|
||||
using LocationDecoration = std::map<UInt32, UInt32>;
|
||||
using ExtInstList = std::unordered_set<std::string>;
|
||||
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
|
||||
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
|
||||
using StructContainer = std::vector<ShaderAst::StructDescription>;
|
||||
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
m_constantCache(constantCache),
|
||||
m_externalBlockIndex(0),
|
||||
m_funcs(funcs)
|
||||
{
|
||||
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
|
||||
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(identifierName);
|
||||
if (!identifier)
|
||||
throw std::runtime_error("invalid identifier " + identifierName);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
|
||||
assert(structIndex < declaredStructs.size());
|
||||
return declaredStructs[structIndex];
|
||||
});
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::AccessMemberExpression& node) override
|
||||
void Visit(ShaderAst::AccessMemberIndexExpression& node) override
|
||||
{
|
||||
/*for (std::size_t index : node.memberIdentifiers)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
for (std::size_t index : node.memberIndices)
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(index)));
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override
|
||||
@@ -64,6 +93,8 @@ namespace Nz
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);*/
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalStatement& node) override
|
||||
@@ -79,52 +110,189 @@ namespace Nz
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(arg));
|
||||
}, node.value);
|
||||
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareExternalStatement& node) override
|
||||
{
|
||||
assert(node.varIndex);
|
||||
std::size_t varIndex = *node.varIndex;
|
||||
for (auto& extVar : node.externalVars)
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = extVar.name;
|
||||
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
|
||||
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
|
||||
|
||||
UniformVar& uniformVar = extVars[varIndex++];
|
||||
uniformVar.pointerId = m_constantCache.Register(variable);
|
||||
|
||||
for (const auto& [attributeType, attributeParam] : extVar.attributes)
|
||||
{
|
||||
if (attributeType == ShaderAst::AttributeType::Binding)
|
||||
{
|
||||
uniformVar.bindingIndex = std::get<long long>(attributeParam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
funcs.emplace_back(node);
|
||||
std::optional<ShaderStageType> entryPointType;
|
||||
for (auto& attribute : node.attributes)
|
||||
{
|
||||
if (attribute.type == ShaderAst::AttributeType::Entry)
|
||||
{
|
||||
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
|
||||
assert(it != s_entryPoints.end());
|
||||
|
||||
std::vector<ShaderAst::ExpressionType> parameterTypes;
|
||||
for (auto& parameter : node.parameters)
|
||||
parameterTypes.push_back(parameter.type);
|
||||
entryPointType = it->second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
|
||||
assert(node.funcIndex);
|
||||
std::size_t funcIndex = *node.funcIndex;
|
||||
|
||||
if (funcIndex >= m_funcs.size())
|
||||
m_funcs.resize(funcIndex + 1);
|
||||
|
||||
auto& funcData = m_funcs[funcIndex];
|
||||
funcData.name = node.name;
|
||||
|
||||
if (!entryPointType)
|
||||
{
|
||||
std::vector<ShaderAst::ExpressionType> parameterTypes;
|
||||
for (auto& parameter : node.parameters)
|
||||
parameterTypes.push_back(parameter.type);
|
||||
|
||||
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
|
||||
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes));
|
||||
|
||||
for (auto& parameter : node.parameters)
|
||||
{
|
||||
auto& funcParam = funcData.parameters.emplace_back();
|
||||
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function));
|
||||
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using EntryPoint = SpirvAstVisitor::EntryPoint;
|
||||
|
||||
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
|
||||
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {}));
|
||||
|
||||
std::optional<EntryPoint::InputStruct> inputStruct;
|
||||
std::vector<EntryPoint::Input> inputs;
|
||||
if (!node.parameters.empty())
|
||||
{
|
||||
assert(node.parameters.size() == 1);
|
||||
auto& parameter = node.parameters.front();
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
|
||||
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
|
||||
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
|
||||
|
||||
std::size_t memberIndex = 0;
|
||||
for (const auto& member : structDesc.members)
|
||||
{
|
||||
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
|
||||
{
|
||||
inputs.push_back({
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
|
||||
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)),
|
||||
varId
|
||||
});
|
||||
}
|
||||
|
||||
memberIndex++;
|
||||
}
|
||||
|
||||
inputStruct = EntryPoint::InputStruct{
|
||||
m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)),
|
||||
m_constantCache.Register(*m_constantCache.BuildType(parameter.type))
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<UInt32> outputStructId;
|
||||
std::vector<EntryPoint::Output> outputs;
|
||||
if (!IsNoType(node.returnType))
|
||||
{
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
|
||||
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
|
||||
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
|
||||
|
||||
std::size_t memberIndex = 0;
|
||||
for (const auto& member : structDesc.members)
|
||||
{
|
||||
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
|
||||
{
|
||||
outputs.push_back({
|
||||
Int32(memberIndex),
|
||||
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
|
||||
varId
|
||||
});
|
||||
}
|
||||
|
||||
memberIndex++;
|
||||
}
|
||||
|
||||
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
|
||||
}
|
||||
|
||||
funcData.entryPointData = EntryPoint{
|
||||
*entryPointType,
|
||||
inputStruct,
|
||||
outputStructId,
|
||||
funcIndex,
|
||||
std::move(inputs),
|
||||
std::move(outputs)
|
||||
};
|
||||
}
|
||||
|
||||
m_funcIndex = funcIndex;
|
||||
AstScopedVisitor::Visit(node);
|
||||
m_funcIndex.reset();
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareStructStatement& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
SpirvConstantCache::Structure sType;
|
||||
sType.name = node.description.name;
|
||||
assert(node.structIndex);
|
||||
std::size_t structIndex = *node.structIndex;
|
||||
if (structIndex >= declaredStructs.size())
|
||||
declaredStructs.resize(structIndex + 1);
|
||||
|
||||
for (const auto& [name, attribute, type] : node.description.members)
|
||||
{
|
||||
auto& sMembers = sType.members.emplace_back();
|
||||
sMembers.name = name;
|
||||
sMembers.type = SpirvConstantCache::BuildType(type);
|
||||
}
|
||||
declaredStructs[structIndex] = node.description;
|
||||
|
||||
m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) });
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.description));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
|
||||
assert(m_funcIndex);
|
||||
auto& func = m_funcs[*m_funcIndex];
|
||||
|
||||
assert(node.varIndex);
|
||||
func.varIndexToVarId[*node.varIndex] = func.variables.size();
|
||||
|
||||
auto& var = func.variables.emplace_back();
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
@@ -144,40 +312,88 @@ namespace Nz
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
break;
|
||||
}
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::SwizzleExpression& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
|
||||
{
|
||||
std::optional<std::reference_wrapper<Builtin>> builtinOpt;
|
||||
std::optional<long long> attributeLocation;
|
||||
for (const auto& [attributeType, attributeParam] : member.attributes)
|
||||
{
|
||||
if (attributeType == ShaderAst::AttributeType::Builtin)
|
||||
{
|
||||
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
|
||||
if (it != s_builtinMapping.end())
|
||||
{
|
||||
builtinOpt = it->second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else if (attributeType == ShaderAst::AttributeType::Location)
|
||||
{
|
||||
attributeLocation = std::get<long long>(attributeParam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (builtinOpt)
|
||||
{
|
||||
Builtin& builtin = *builtinOpt;
|
||||
if ((builtin.compatibleStages & entryPointType) == 0)
|
||||
return 0;
|
||||
|
||||
SpirvBuiltIn builtinDecoration = builtin.decoration;
|
||||
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = builtin.debugName;
|
||||
variable.funcId = funcIndex;
|
||||
variable.storageClass = storageClass;
|
||||
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
|
||||
|
||||
UInt32 varId = m_constantCache.Register(variable);
|
||||
builtinDecorations[varId] = builtinDecoration;
|
||||
|
||||
return varId;
|
||||
}
|
||||
else if (attributeLocation)
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = member.name;
|
||||
variable.funcId = funcIndex;
|
||||
variable.storageClass = storageClass;
|
||||
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
|
||||
|
||||
UInt32 varId = m_constantCache.Register(variable);
|
||||
locationDecorations[varId] = *attributeLocation;
|
||||
|
||||
return varId;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
BuiltinDecoration builtinDecorations;
|
||||
ExtInstList extInsts;
|
||||
FunctionContainer funcs;
|
||||
ExtVarContainer extVars;
|
||||
LocationDecoration locationDecorations;
|
||||
StructContainer declaredStructs;
|
||||
|
||||
private:
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
std::optional<std::size_t> m_funcIndex;
|
||||
std::size_t m_externalBlockIndex;
|
||||
std::vector<SpirvAstVisitor::FuncData>& m_funcs;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
constexpr ShaderAst::PrimitiveType GetBasicType()
|
||||
{
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return ShaderAst::PrimitiveType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return(ShaderAst::PrimitiveType::Float32);
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return(ShaderAst::PrimitiveType::Int32);
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return(ShaderAst::PrimitiveType::Float2);
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return(ShaderAst::PrimitiveType::Float3);
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return(ShaderAst::PrimitiveType::Float4);
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return(ShaderAst::PrimitiveType::Int2);
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return(ShaderAst::PrimitiveType::Int3);
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return(ShaderAst::PrimitiveType::Int4);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "unhandled type");
|
||||
}
|
||||
}
|
||||
|
||||
struct SpirvWriter::State
|
||||
@@ -194,18 +410,13 @@ namespace Nz
|
||||
UInt32 id;
|
||||
};
|
||||
|
||||
tsl::ordered_map<std::string, ExtVar> inputIds;
|
||||
tsl::ordered_map<std::string, ExtVar> outputIds;
|
||||
tsl::ordered_map<std::string, ExtVar> parameterIds;
|
||||
tsl::ordered_map<std::string, ExtVar> uniformIds;
|
||||
std::unordered_map<std::string, UInt32> extensionInstructions;
|
||||
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<std::string, UInt32> varToResult;
|
||||
std::vector<Func> funcs;
|
||||
std::vector<SpirvBlock> functionBlocks;
|
||||
std::vector<SpirvAstVisitor::FuncData> funcs;
|
||||
std::vector<UInt32> resultIds;
|
||||
UInt32 nextVarIndex = 1;
|
||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||
PreVisitor* preVisitor;
|
||||
|
||||
// Output
|
||||
SpirvSection header;
|
||||
@@ -226,6 +437,9 @@ namespace Nz
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
ShaderAst::TransformVisitor transformVisitor;
|
||||
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
|
||||
|
||||
m_context.states = &conditions;
|
||||
|
||||
State state;
|
||||
@@ -235,245 +449,37 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
ShaderAst::AstCloner cloner;
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache);
|
||||
shader->Visit(preVisitor);
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs);
|
||||
transformedShader->Visit(preVisitor);
|
||||
|
||||
m_currentState->preVisitor = &preVisitor;
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
state.extensionInstructions[extInst] = AllocateResultId();
|
||||
|
||||
// Register all types
|
||||
/*for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
RegisterType(func.returnType);
|
||||
for (const auto& param : func.parameters)
|
||||
RegisterType(param.type);
|
||||
}
|
||||
|
||||
for (const auto& input : shader.GetInputs())
|
||||
RegisterPointerType(input.type, SpirvStorageClass::Input);
|
||||
|
||||
for (const auto& output : shader.GetOutputs())
|
||||
RegisterPointerType(output.type, SpirvStorageClass::Output);
|
||||
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
RegisterPointerType(uniform.type, (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform);
|
||||
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
RegisterFunctionType(func.returnType, func.parameters);
|
||||
|
||||
for (const auto& type : preVisitor.variableTypes)
|
||||
RegisterType(type);
|
||||
|
||||
for (const auto& builtin : preVisitor.builtinVars)
|
||||
RegisterType(builtin->type);
|
||||
|
||||
// Register result id and debug infos for global variables/functions
|
||||
for (const auto& builtin : preVisitor.builtinVars)
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
SpirvBuiltIn builtinDecoration;
|
||||
switch (builtin->entry)
|
||||
{
|
||||
case ShaderAst::BuiltinEntry::VertexPosition:
|
||||
variable.debugName = "builtin_VertexPosition";
|
||||
variable.storageClass = SpirvStorageClass::Output;
|
||||
|
||||
builtinDecoration = SpirvBuiltIn::Position;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("unexpected builtin type");
|
||||
}
|
||||
|
||||
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
|
||||
assert(IsBasicType(builtinExprType));
|
||||
|
||||
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
|
||||
|
||||
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar builtinData;
|
||||
builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass);
|
||||
builtinData.typeId = GetTypeId(builtinType);
|
||||
builtinData.varId = varId;
|
||||
|
||||
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);
|
||||
|
||||
state.builtinIds.emplace(builtin->entry, builtinData);
|
||||
}
|
||||
|
||||
for (const auto& input : shader.GetInputs())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = input.name;
|
||||
variable.storageClass = SpirvStorageClass::Input;
|
||||
variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar inputData;
|
||||
inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass);
|
||||
inputData.typeId = GetTypeId(input.type);
|
||||
inputData.varId = varId;
|
||||
|
||||
state.inputIds.emplace(input.name, std::move(inputData));
|
||||
|
||||
if (input.locationIndex)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
|
||||
}
|
||||
|
||||
for (const auto& output : shader.GetOutputs())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = output.name;
|
||||
variable.storageClass = SpirvStorageClass::Output;
|
||||
variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar outputData;
|
||||
outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass);
|
||||
outputData.typeId = GetTypeId(output.type);
|
||||
outputData.varId = varId;
|
||||
|
||||
state.outputIds.emplace(output.name, std::move(outputData));
|
||||
|
||||
if (output.locationIndex)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
|
||||
}
|
||||
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = uniform.name;
|
||||
variable.storageClass = (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
|
||||
variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar uniformData;
|
||||
uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass);
|
||||
uniformData.typeId = GetTypeId(uniform.type);
|
||||
uniformData.varId = varId;
|
||||
|
||||
state.uniformIds.emplace(uniform.name, std::move(uniformData));
|
||||
|
||||
if (uniform.bindingIndex)
|
||||
{
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
|
||||
}
|
||||
}*/
|
||||
|
||||
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
|
||||
{
|
||||
auto& funcData = state.funcs.emplace_back();
|
||||
funcData.statement = &func;
|
||||
funcData.id = AllocateResultId();
|
||||
funcData.typeId = GetFunctionTypeId(func);
|
||||
|
||||
state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
|
||||
}
|
||||
|
||||
std::size_t funcIndex = 0;
|
||||
|
||||
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
|
||||
{
|
||||
auto& funcData = state.funcs[funcIndex++];
|
||||
|
||||
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
|
||||
|
||||
state.functionBlocks.clear();
|
||||
state.functionBlocks.emplace_back(*this);
|
||||
|
||||
state.parameterIds.clear();
|
||||
|
||||
for (const auto& param : func.parameters)
|
||||
{
|
||||
UInt32 paramResultId = AllocateResultId();
|
||||
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
|
||||
|
||||
ExtVar parameterData;
|
||||
parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
|
||||
parameterData.typeId = GetTypeId(param.type);
|
||||
parameterData.varId = paramResultId;
|
||||
|
||||
state.parameterIds.emplace(param.name, std::move(parameterData));
|
||||
}
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks);
|
||||
for (const auto& statement : func.statements)
|
||||
statement->Visit(visitor);
|
||||
|
||||
if (!state.functionBlocks.back().IsTerminated())
|
||||
{
|
||||
assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} });
|
||||
state.functionBlocks.back().Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
for (SpirvBlock& block : state.functionBlocks)
|
||||
state.instructions.AppendSection(block);
|
||||
|
||||
state.instructions.Append(SpirvOp::OpFunctionEnd);
|
||||
}
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||
transformedShader->Visit(visitor);
|
||||
|
||||
AppendHeader();
|
||||
|
||||
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
|
||||
for (auto&& [varIndex, extVar] : preVisitor.extVars)
|
||||
{
|
||||
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
|
||||
if (!statement)
|
||||
continue;
|
||||
|
||||
auto it = std::find_if(state.funcs.begin(), state.funcs.end(), [&](const auto& funcData) { return funcData.statement == statement; });
|
||||
assert(it != state.funcs.end());
|
||||
|
||||
const auto& entryFunc = *it;
|
||||
|
||||
SpirvExecutionModel execModel;
|
||||
|
||||
ShaderStageType stage = static_cast<ShaderStageType>(i);
|
||||
switch (stage)
|
||||
if (extVar.bindingIndex)
|
||||
{
|
||||
case ShaderStageType::Fragment:
|
||||
execModel = SpirvExecutionModel::Fragment;
|
||||
break;
|
||||
|
||||
case ShaderStageType::Vertex:
|
||||
execModel = SpirvExecutionModel::Vertex;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, 0);
|
||||
}
|
||||
|
||||
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
|
||||
{
|
||||
appender(execModel);
|
||||
appender(entryFunc.id);
|
||||
appender(statement->name);
|
||||
|
||||
for (const auto& [name, varData] : state.builtinIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.inputIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.outputIds)
|
||||
appender(varData.varId);
|
||||
});
|
||||
|
||||
if (stage == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
|
||||
}
|
||||
|
||||
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
|
||||
|
||||
for (auto&& [varId, location] : preVisitor.locationDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
MergeSections(ret, state.header);
|
||||
MergeSections(ret, state.debugInfo);
|
||||
@@ -511,171 +517,53 @@ namespace Nz
|
||||
m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);
|
||||
|
||||
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->builtinIds.find(builtin);
|
||||
assert(it != m_currentState->builtinIds.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->inputIds.find(name);
|
||||
assert(it != m_currentState->inputIds.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->outputIds.find(name);
|
||||
assert(it != m_currentState->outputIds.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
assert(it != m_currentState->uniformIds.end());
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->inputIds.find(name);
|
||||
assert(it != m_currentState->inputIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->inputIds.find(name);
|
||||
assert(it != m_currentState->inputIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadLocalVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->varToResult.find(name);
|
||||
assert(it != m_currentState->varToResult.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->varToResult.find(name);
|
||||
if (it == m_currentState->varToResult.end())
|
||||
return {};
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->parameterIds.find(name);
|
||||
assert(it != m_currentState->parameterIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->parameterIds.find(name);
|
||||
assert(it != m_currentState->parameterIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
assert(it != m_currentState->uniformIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
assert(it != m_currentState->uniformIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadVariable(ExtVar& var)
|
||||
{
|
||||
if (!var.valueId.has_value())
|
||||
std::optional<UInt32> fragmentFuncId;
|
||||
for (auto& func : m_currentState->funcs)
|
||||
{
|
||||
UInt32 resultId = AllocateResultId();
|
||||
m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
|
||||
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
|
||||
|
||||
var.valueId = resultId;
|
||||
if (func.entryPointData)
|
||||
{
|
||||
auto& entryPointData = func.entryPointData.value();
|
||||
|
||||
SpirvExecutionModel execModel;
|
||||
|
||||
switch (entryPointData.stageType)
|
||||
{
|
||||
case ShaderStageType::Fragment:
|
||||
execModel = SpirvExecutionModel::Fragment;
|
||||
break;
|
||||
|
||||
case ShaderStageType::Vertex:
|
||||
execModel = SpirvExecutionModel::Vertex;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
}
|
||||
|
||||
m_currentState->header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
|
||||
{
|
||||
appender(execModel);
|
||||
appender(func.funcId);
|
||||
appender(func.name);
|
||||
|
||||
for (const auto& input : entryPointData.inputs)
|
||||
appender(input.varId);
|
||||
|
||||
for (const auto& output : entryPointData.outputs)
|
||||
appender(output.varId);
|
||||
});
|
||||
|
||||
if (entryPointData.stageType == ShaderStageType::Fragment)
|
||||
fragmentFuncId = func.funcId;
|
||||
}
|
||||
}
|
||||
|
||||
return var.valueId.value();
|
||||
}
|
||||
if (fragmentFuncId)
|
||||
m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft);
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache)
|
||||
{
|
||||
if (!var.valueId.has_value())
|
||||
return {};
|
||||
|
||||
return var.valueId.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
|
||||
{
|
||||
assert(m_currentState);
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
|
||||
{
|
||||
assert(m_currentState);
|
||||
m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
|
||||
}
|
||||
|
||||
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
@@ -686,7 +574,56 @@ namespace Nz
|
||||
for (const auto& parameter : functionNode.parameters)
|
||||
parameterTypes.push_back(parameter.type);
|
||||
|
||||
return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes);
|
||||
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes);
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
|
||||
{
|
||||
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
|
||||
assert(it != m_currentState->preVisitor->extVars.end());
|
||||
|
||||
return it->second.pointerId;
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
|
||||
{
|
||||
assert(m_currentState);
|
||||
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildType(type));
|
||||
}
|
||||
|
||||
void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from)
|
||||
|
||||
Reference in New Issue
Block a user