Shader: First working version on both Vulkan & OpenGL (ES)
This commit is contained in:
@@ -12,10 +12,12 @@
|
||||
#include <Nazara/Shader/SpirvConstantCache.hpp>
|
||||
#include <Nazara/Shader/SpirvData.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
|
||||
#include <tsl/ordered_map.h>
|
||||
#include <tsl/ordered_set.h>
|
||||
#include <SpirV/GLSL.std.450.h>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@@ -25,34 +27,61 @@ namespace Nz
|
||||
{
|
||||
namespace
|
||||
{
|
||||
//FIXME: Have this only once
|
||||
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
|
||||
{ "frag", ShaderStageType::Fragment },
|
||||
{ "vert", ShaderStageType::Vertex },
|
||||
};
|
||||
|
||||
struct Builtin
|
||||
{
|
||||
const char* debugName;
|
||||
ShaderStageTypeFlags compatibleStages;
|
||||
SpirvBuiltIn decoration;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, Builtin> s_builtinMapping = {
|
||||
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
|
||||
};
|
||||
|
||||
class PreVisitor : public ShaderAst::AstScopedVisitor
|
||||
{
|
||||
public:
|
||||
struct UniformVar
|
||||
{
|
||||
std::optional<UInt32> bindingIndex;
|
||||
UInt32 pointerId;
|
||||
};
|
||||
|
||||
using BuiltinDecoration = std::map<UInt32, SpirvBuiltIn>;
|
||||
using LocationDecoration = std::map<UInt32, UInt32>;
|
||||
using ExtInstList = std::unordered_set<std::string>;
|
||||
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
|
||||
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
|
||||
using StructContainer = std::vector<ShaderAst::StructDescription>;
|
||||
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
m_constantCache(constantCache),
|
||||
m_externalBlockIndex(0),
|
||||
m_funcs(funcs)
|
||||
{
|
||||
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
|
||||
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(identifierName);
|
||||
if (!identifier)
|
||||
throw std::runtime_error("invalid identifier " + identifierName);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
|
||||
assert(structIndex < declaredStructs.size());
|
||||
return declaredStructs[structIndex];
|
||||
});
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::AccessMemberExpression& node) override
|
||||
void Visit(ShaderAst::AccessMemberIndexExpression& node) override
|
||||
{
|
||||
/*for (std::size_t index : node.memberIdentifiers)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
for (std::size_t index : node.memberIndices)
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(index)));
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override
|
||||
@@ -64,6 +93,8 @@ namespace Nz
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);*/
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalStatement& node) override
|
||||
@@ -79,52 +110,189 @@ namespace Nz
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(arg));
|
||||
}, node.value);
|
||||
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareExternalStatement& node) override
|
||||
{
|
||||
assert(node.varIndex);
|
||||
std::size_t varIndex = *node.varIndex;
|
||||
for (auto& extVar : node.externalVars)
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = extVar.name;
|
||||
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
|
||||
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
|
||||
|
||||
UniformVar& uniformVar = extVars[varIndex++];
|
||||
uniformVar.pointerId = m_constantCache.Register(variable);
|
||||
|
||||
for (const auto& [attributeType, attributeParam] : extVar.attributes)
|
||||
{
|
||||
if (attributeType == ShaderAst::AttributeType::Binding)
|
||||
{
|
||||
uniformVar.bindingIndex = std::get<long long>(attributeParam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
funcs.emplace_back(node);
|
||||
std::optional<ShaderStageType> entryPointType;
|
||||
for (auto& attribute : node.attributes)
|
||||
{
|
||||
if (attribute.type == ShaderAst::AttributeType::Entry)
|
||||
{
|
||||
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
|
||||
assert(it != s_entryPoints.end());
|
||||
|
||||
std::vector<ShaderAst::ExpressionType> parameterTypes;
|
||||
for (auto& parameter : node.parameters)
|
||||
parameterTypes.push_back(parameter.type);
|
||||
entryPointType = it->second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
|
||||
assert(node.funcIndex);
|
||||
std::size_t funcIndex = *node.funcIndex;
|
||||
|
||||
if (funcIndex >= m_funcs.size())
|
||||
m_funcs.resize(funcIndex + 1);
|
||||
|
||||
auto& funcData = m_funcs[funcIndex];
|
||||
funcData.name = node.name;
|
||||
|
||||
if (!entryPointType)
|
||||
{
|
||||
std::vector<ShaderAst::ExpressionType> parameterTypes;
|
||||
for (auto& parameter : node.parameters)
|
||||
parameterTypes.push_back(parameter.type);
|
||||
|
||||
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
|
||||
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes));
|
||||
|
||||
for (auto& parameter : node.parameters)
|
||||
{
|
||||
auto& funcParam = funcData.parameters.emplace_back();
|
||||
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function));
|
||||
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using EntryPoint = SpirvAstVisitor::EntryPoint;
|
||||
|
||||
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
|
||||
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {}));
|
||||
|
||||
std::optional<EntryPoint::InputStruct> inputStruct;
|
||||
std::vector<EntryPoint::Input> inputs;
|
||||
if (!node.parameters.empty())
|
||||
{
|
||||
assert(node.parameters.size() == 1);
|
||||
auto& parameter = node.parameters.front();
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
|
||||
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
|
||||
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
|
||||
|
||||
std::size_t memberIndex = 0;
|
||||
for (const auto& member : structDesc.members)
|
||||
{
|
||||
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
|
||||
{
|
||||
inputs.push_back({
|
||||
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
|
||||
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)),
|
||||
varId
|
||||
});
|
||||
}
|
||||
|
||||
memberIndex++;
|
||||
}
|
||||
|
||||
inputStruct = EntryPoint::InputStruct{
|
||||
m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)),
|
||||
m_constantCache.Register(*m_constantCache.BuildType(parameter.type))
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<UInt32> outputStructId;
|
||||
std::vector<EntryPoint::Output> outputs;
|
||||
if (!IsNoType(node.returnType))
|
||||
{
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
|
||||
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
|
||||
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
|
||||
|
||||
std::size_t memberIndex = 0;
|
||||
for (const auto& member : structDesc.members)
|
||||
{
|
||||
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
|
||||
{
|
||||
outputs.push_back({
|
||||
Int32(memberIndex),
|
||||
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
|
||||
varId
|
||||
});
|
||||
}
|
||||
|
||||
memberIndex++;
|
||||
}
|
||||
|
||||
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
|
||||
}
|
||||
|
||||
funcData.entryPointData = EntryPoint{
|
||||
*entryPointType,
|
||||
inputStruct,
|
||||
outputStructId,
|
||||
funcIndex,
|
||||
std::move(inputs),
|
||||
std::move(outputs)
|
||||
};
|
||||
}
|
||||
|
||||
m_funcIndex = funcIndex;
|
||||
AstScopedVisitor::Visit(node);
|
||||
m_funcIndex.reset();
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareStructStatement& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
SpirvConstantCache::Structure sType;
|
||||
sType.name = node.description.name;
|
||||
assert(node.structIndex);
|
||||
std::size_t structIndex = *node.structIndex;
|
||||
if (structIndex >= declaredStructs.size())
|
||||
declaredStructs.resize(structIndex + 1);
|
||||
|
||||
for (const auto& [name, attribute, type] : node.description.members)
|
||||
{
|
||||
auto& sMembers = sType.members.emplace_back();
|
||||
sMembers.name = name;
|
||||
sMembers.type = SpirvConstantCache::BuildType(type);
|
||||
}
|
||||
declaredStructs[structIndex] = node.description;
|
||||
|
||||
m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) });
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.description));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
|
||||
assert(m_funcIndex);
|
||||
auto& func = m_funcs[*m_funcIndex];
|
||||
|
||||
assert(node.varIndex);
|
||||
func.varIndexToVarId[*node.varIndex] = func.variables.size();
|
||||
|
||||
auto& var = func.variables.emplace_back();
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
@@ -144,40 +312,88 @@ namespace Nz
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
break;
|
||||
}
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::SwizzleExpression& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
|
||||
{
|
||||
std::optional<std::reference_wrapper<Builtin>> builtinOpt;
|
||||
std::optional<long long> attributeLocation;
|
||||
for (const auto& [attributeType, attributeParam] : member.attributes)
|
||||
{
|
||||
if (attributeType == ShaderAst::AttributeType::Builtin)
|
||||
{
|
||||
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
|
||||
if (it != s_builtinMapping.end())
|
||||
{
|
||||
builtinOpt = it->second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else if (attributeType == ShaderAst::AttributeType::Location)
|
||||
{
|
||||
attributeLocation = std::get<long long>(attributeParam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (builtinOpt)
|
||||
{
|
||||
Builtin& builtin = *builtinOpt;
|
||||
if ((builtin.compatibleStages & entryPointType) == 0)
|
||||
return 0;
|
||||
|
||||
SpirvBuiltIn builtinDecoration = builtin.decoration;
|
||||
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = builtin.debugName;
|
||||
variable.funcId = funcIndex;
|
||||
variable.storageClass = storageClass;
|
||||
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
|
||||
|
||||
UInt32 varId = m_constantCache.Register(variable);
|
||||
builtinDecorations[varId] = builtinDecoration;
|
||||
|
||||
return varId;
|
||||
}
|
||||
else if (attributeLocation)
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = member.name;
|
||||
variable.funcId = funcIndex;
|
||||
variable.storageClass = storageClass;
|
||||
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
|
||||
|
||||
UInt32 varId = m_constantCache.Register(variable);
|
||||
locationDecorations[varId] = *attributeLocation;
|
||||
|
||||
return varId;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
BuiltinDecoration builtinDecorations;
|
||||
ExtInstList extInsts;
|
||||
FunctionContainer funcs;
|
||||
ExtVarContainer extVars;
|
||||
LocationDecoration locationDecorations;
|
||||
StructContainer declaredStructs;
|
||||
|
||||
private:
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
std::optional<std::size_t> m_funcIndex;
|
||||
std::size_t m_externalBlockIndex;
|
||||
std::vector<SpirvAstVisitor::FuncData>& m_funcs;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
constexpr ShaderAst::PrimitiveType GetBasicType()
|
||||
{
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return ShaderAst::PrimitiveType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return(ShaderAst::PrimitiveType::Float32);
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return(ShaderAst::PrimitiveType::Int32);
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return(ShaderAst::PrimitiveType::Float2);
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return(ShaderAst::PrimitiveType::Float3);
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return(ShaderAst::PrimitiveType::Float4);
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return(ShaderAst::PrimitiveType::Int2);
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return(ShaderAst::PrimitiveType::Int3);
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return(ShaderAst::PrimitiveType::Int4);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "unhandled type");
|
||||
}
|
||||
}
|
||||
|
||||
struct SpirvWriter::State
|
||||
@@ -194,18 +410,13 @@ namespace Nz
|
||||
UInt32 id;
|
||||
};
|
||||
|
||||
tsl::ordered_map<std::string, ExtVar> inputIds;
|
||||
tsl::ordered_map<std::string, ExtVar> outputIds;
|
||||
tsl::ordered_map<std::string, ExtVar> parameterIds;
|
||||
tsl::ordered_map<std::string, ExtVar> uniformIds;
|
||||
std::unordered_map<std::string, UInt32> extensionInstructions;
|
||||
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<std::string, UInt32> varToResult;
|
||||
std::vector<Func> funcs;
|
||||
std::vector<SpirvBlock> functionBlocks;
|
||||
std::vector<SpirvAstVisitor::FuncData> funcs;
|
||||
std::vector<UInt32> resultIds;
|
||||
UInt32 nextVarIndex = 1;
|
||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||
PreVisitor* preVisitor;
|
||||
|
||||
// Output
|
||||
SpirvSection header;
|
||||
@@ -226,6 +437,9 @@ namespace Nz
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
ShaderAst::TransformVisitor transformVisitor;
|
||||
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
|
||||
|
||||
m_context.states = &conditions;
|
||||
|
||||
State state;
|
||||
@@ -235,245 +449,37 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
ShaderAst::AstCloner cloner;
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache);
|
||||
shader->Visit(preVisitor);
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs);
|
||||
transformedShader->Visit(preVisitor);
|
||||
|
||||
m_currentState->preVisitor = &preVisitor;
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
state.extensionInstructions[extInst] = AllocateResultId();
|
||||
|
||||
// Register all types
|
||||
/*for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
RegisterType(func.returnType);
|
||||
for (const auto& param : func.parameters)
|
||||
RegisterType(param.type);
|
||||
}
|
||||
|
||||
for (const auto& input : shader.GetInputs())
|
||||
RegisterPointerType(input.type, SpirvStorageClass::Input);
|
||||
|
||||
for (const auto& output : shader.GetOutputs())
|
||||
RegisterPointerType(output.type, SpirvStorageClass::Output);
|
||||
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
RegisterPointerType(uniform.type, (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform);
|
||||
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
RegisterFunctionType(func.returnType, func.parameters);
|
||||
|
||||
for (const auto& type : preVisitor.variableTypes)
|
||||
RegisterType(type);
|
||||
|
||||
for (const auto& builtin : preVisitor.builtinVars)
|
||||
RegisterType(builtin->type);
|
||||
|
||||
// Register result id and debug infos for global variables/functions
|
||||
for (const auto& builtin : preVisitor.builtinVars)
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
SpirvBuiltIn builtinDecoration;
|
||||
switch (builtin->entry)
|
||||
{
|
||||
case ShaderAst::BuiltinEntry::VertexPosition:
|
||||
variable.debugName = "builtin_VertexPosition";
|
||||
variable.storageClass = SpirvStorageClass::Output;
|
||||
|
||||
builtinDecoration = SpirvBuiltIn::Position;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("unexpected builtin type");
|
||||
}
|
||||
|
||||
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
|
||||
assert(IsBasicType(builtinExprType));
|
||||
|
||||
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
|
||||
|
||||
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar builtinData;
|
||||
builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass);
|
||||
builtinData.typeId = GetTypeId(builtinType);
|
||||
builtinData.varId = varId;
|
||||
|
||||
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);
|
||||
|
||||
state.builtinIds.emplace(builtin->entry, builtinData);
|
||||
}
|
||||
|
||||
for (const auto& input : shader.GetInputs())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = input.name;
|
||||
variable.storageClass = SpirvStorageClass::Input;
|
||||
variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar inputData;
|
||||
inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass);
|
||||
inputData.typeId = GetTypeId(input.type);
|
||||
inputData.varId = varId;
|
||||
|
||||
state.inputIds.emplace(input.name, std::move(inputData));
|
||||
|
||||
if (input.locationIndex)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
|
||||
}
|
||||
|
||||
for (const auto& output : shader.GetOutputs())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = output.name;
|
||||
variable.storageClass = SpirvStorageClass::Output;
|
||||
variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar outputData;
|
||||
outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass);
|
||||
outputData.typeId = GetTypeId(output.type);
|
||||
outputData.varId = varId;
|
||||
|
||||
state.outputIds.emplace(output.name, std::move(outputData));
|
||||
|
||||
if (output.locationIndex)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
|
||||
}
|
||||
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = uniform.name;
|
||||
variable.storageClass = (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
|
||||
variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass);
|
||||
|
||||
UInt32 varId = m_currentState->constantTypeCache.Register(variable);
|
||||
|
||||
ExtVar uniformData;
|
||||
uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass);
|
||||
uniformData.typeId = GetTypeId(uniform.type);
|
||||
uniformData.varId = varId;
|
||||
|
||||
state.uniformIds.emplace(uniform.name, std::move(uniformData));
|
||||
|
||||
if (uniform.bindingIndex)
|
||||
{
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
|
||||
}
|
||||
}*/
|
||||
|
||||
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
|
||||
{
|
||||
auto& funcData = state.funcs.emplace_back();
|
||||
funcData.statement = &func;
|
||||
funcData.id = AllocateResultId();
|
||||
funcData.typeId = GetFunctionTypeId(func);
|
||||
|
||||
state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
|
||||
}
|
||||
|
||||
std::size_t funcIndex = 0;
|
||||
|
||||
for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
|
||||
{
|
||||
auto& funcData = state.funcs[funcIndex++];
|
||||
|
||||
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
|
||||
|
||||
state.functionBlocks.clear();
|
||||
state.functionBlocks.emplace_back(*this);
|
||||
|
||||
state.parameterIds.clear();
|
||||
|
||||
for (const auto& param : func.parameters)
|
||||
{
|
||||
UInt32 paramResultId = AllocateResultId();
|
||||
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
|
||||
|
||||
ExtVar parameterData;
|
||||
parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
|
||||
parameterData.typeId = GetTypeId(param.type);
|
||||
parameterData.varId = paramResultId;
|
||||
|
||||
state.parameterIds.emplace(param.name, std::move(parameterData));
|
||||
}
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks);
|
||||
for (const auto& statement : func.statements)
|
||||
statement->Visit(visitor);
|
||||
|
||||
if (!state.functionBlocks.back().IsTerminated())
|
||||
{
|
||||
assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} });
|
||||
state.functionBlocks.back().Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
for (SpirvBlock& block : state.functionBlocks)
|
||||
state.instructions.AppendSection(block);
|
||||
|
||||
state.instructions.Append(SpirvOp::OpFunctionEnd);
|
||||
}
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||
transformedShader->Visit(visitor);
|
||||
|
||||
AppendHeader();
|
||||
|
||||
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
|
||||
for (auto&& [varIndex, extVar] : preVisitor.extVars)
|
||||
{
|
||||
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
|
||||
if (!statement)
|
||||
continue;
|
||||
|
||||
auto it = std::find_if(state.funcs.begin(), state.funcs.end(), [&](const auto& funcData) { return funcData.statement == statement; });
|
||||
assert(it != state.funcs.end());
|
||||
|
||||
const auto& entryFunc = *it;
|
||||
|
||||
SpirvExecutionModel execModel;
|
||||
|
||||
ShaderStageType stage = static_cast<ShaderStageType>(i);
|
||||
switch (stage)
|
||||
if (extVar.bindingIndex)
|
||||
{
|
||||
case ShaderStageType::Fragment:
|
||||
execModel = SpirvExecutionModel::Fragment;
|
||||
break;
|
||||
|
||||
case ShaderStageType::Vertex:
|
||||
execModel = SpirvExecutionModel::Vertex;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, 0);
|
||||
}
|
||||
|
||||
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
|
||||
{
|
||||
appender(execModel);
|
||||
appender(entryFunc.id);
|
||||
appender(statement->name);
|
||||
|
||||
for (const auto& [name, varData] : state.builtinIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.inputIds)
|
||||
appender(varData.varId);
|
||||
|
||||
for (const auto& [name, varData] : state.outputIds)
|
||||
appender(varData.varId);
|
||||
});
|
||||
|
||||
if (stage == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
|
||||
}
|
||||
|
||||
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::BuiltIn, builtin);
|
||||
|
||||
for (auto&& [varId, location] : preVisitor.locationDecorations)
|
||||
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, location);
|
||||
|
||||
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
MergeSections(ret, state.header);
|
||||
MergeSections(ret, state.debugInfo);
|
||||
@@ -511,171 +517,53 @@ namespace Nz
|
||||
m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);
|
||||
|
||||
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->builtinIds.find(builtin);
|
||||
assert(it != m_currentState->builtinIds.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->inputIds.find(name);
|
||||
assert(it != m_currentState->inputIds.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->outputIds.find(name);
|
||||
assert(it != m_currentState->outputIds.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
assert(it != m_currentState->uniformIds.end());
|
||||
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->inputIds.find(name);
|
||||
assert(it != m_currentState->inputIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->inputIds.find(name);
|
||||
assert(it != m_currentState->inputIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadLocalVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->varToResult.find(name);
|
||||
assert(it != m_currentState->varToResult.end());
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->varToResult.find(name);
|
||||
if (it == m_currentState->varToResult.end())
|
||||
return {};
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->parameterIds.find(name);
|
||||
assert(it != m_currentState->parameterIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->parameterIds.find(name);
|
||||
assert(it != m_currentState->parameterIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
assert(it != m_currentState->uniformIds.end());
|
||||
|
||||
return ReadVariable(it.value());
|
||||
}
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache)
|
||||
{
|
||||
auto it = m_currentState->uniformIds.find(name);
|
||||
assert(it != m_currentState->uniformIds.end());
|
||||
|
||||
return ReadVariable(it.value(), OnlyCache{});
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadVariable(ExtVar& var)
|
||||
{
|
||||
if (!var.valueId.has_value())
|
||||
std::optional<UInt32> fragmentFuncId;
|
||||
for (auto& func : m_currentState->funcs)
|
||||
{
|
||||
UInt32 resultId = AllocateResultId();
|
||||
m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
|
||||
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
|
||||
|
||||
var.valueId = resultId;
|
||||
if (func.entryPointData)
|
||||
{
|
||||
auto& entryPointData = func.entryPointData.value();
|
||||
|
||||
SpirvExecutionModel execModel;
|
||||
|
||||
switch (entryPointData.stageType)
|
||||
{
|
||||
case ShaderStageType::Fragment:
|
||||
execModel = SpirvExecutionModel::Fragment;
|
||||
break;
|
||||
|
||||
case ShaderStageType::Vertex:
|
||||
execModel = SpirvExecutionModel::Vertex;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
}
|
||||
|
||||
m_currentState->header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
|
||||
{
|
||||
appender(execModel);
|
||||
appender(func.funcId);
|
||||
appender(func.name);
|
||||
|
||||
for (const auto& input : entryPointData.inputs)
|
||||
appender(input.varId);
|
||||
|
||||
for (const auto& output : entryPointData.outputs)
|
||||
appender(output.varId);
|
||||
});
|
||||
|
||||
if (entryPointData.stageType == ShaderStageType::Fragment)
|
||||
fragmentFuncId = func.funcId;
|
||||
}
|
||||
}
|
||||
|
||||
return var.valueId.value();
|
||||
}
|
||||
if (fragmentFuncId)
|
||||
m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft);
|
||||
|
||||
std::optional<UInt32> SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache)
|
||||
{
|
||||
if (!var.valueId.has_value())
|
||||
return {};
|
||||
|
||||
return var.valueId.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
|
||||
{
|
||||
assert(m_currentState);
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
|
||||
{
|
||||
assert(m_currentState);
|
||||
m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
|
||||
}
|
||||
|
||||
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
@@ -686,7 +574,56 @@ namespace Nz
|
||||
for (const auto& parameter : functionNode.parameters)
|
||||
parameterTypes.push_back(parameter.type);
|
||||
|
||||
return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes);
|
||||
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes);
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetExtVarPointerId(std::size_t extVarIndex) const
|
||||
{
|
||||
auto it = m_currentState->preVisitor->extVars.find(extVarIndex);
|
||||
assert(it != m_currentState->preVisitor->extVars.end());
|
||||
|
||||
return it->second.pointerId;
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
|
||||
{
|
||||
assert(m_currentState);
|
||||
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildType(type));
|
||||
}
|
||||
|
||||
void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from)
|
||||
|
||||
Reference in New Issue
Block a user