// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp

// NOTE(review): in this copy of the file everything that was written between angle brackets
// (include targets, template parameter/argument lists) appears to have been stripped.
// Restore those from the upstream file before trying to build; the code below is otherwise untouched.

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace Nz
{
	namespace
	{
		// AST pre-pass used by SpirvWriter::Generate before any instruction is emitted.
		// While walking the tree it registers constants and types in the SpirvConstantCache it was given
		// and records, in its two public containers, the extended instruction sets required by the
		// intrinsics it meets (extInsts) and the types of the variables/identifiers it meets (variableTypes).
		class PreVisitor : public ShaderAst::AstRecursiveVisitor
		{
			public:
				using ExtInstList = std::unordered_set;
				using LocalContainer = std::unordered_set;

				// cache:         AST cache, forwarded to GetExpressionType when resolving identifiers
				// conditions:    writer states; stored but not read by the code currently enabled in this class
				// constantCache: cache receiving every constant/type registered by this visitor
				PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
				m_cache(cache),
				m_conditions(conditions),
				m_constantCache(constantCache)
				{
				}

				// Only forwards to the base visitor; the registration of member indices is commented out
				void Visit(ShaderAst::AccessMemberExpression& node) override
				{
					/*for (std::size_t index : node.memberIdentifiers)
						m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/

					AstRecursiveVisitor::Visit(node);
				}

				// Currently does nothing: the condition lookup is commented out and the base visitor is not called
				void Visit(ShaderAst::ConditionalExpression& node) override
				{
					/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
					assert(conditionIndex != ShaderAst::InvalidCondition);

					if (TestBit(m_conditions.enabledConditions, conditionIndex))
						Visit(node.truePath);
					else
						Visit(node.falsePath);*/
				}

				// Currently does nothing: the condition lookup is commented out and the base visitor is not called
				void Visit(ShaderAst::ConditionalStatement& node) override
				{
					/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
					assert(conditionIndex != ShaderAst::InvalidCondition);

					if (TestBit(m_conditions.enabledConditions, conditionIndex))
						Visit(node.statement);*/
				}

				// Registers the constant held by the node, whatever alternative node.value currently holds
				void Visit(ShaderAst::ConstantExpression& node) override
				{
					std::visit([&](auto&& arg)
					{
						m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
					}, node.value);

					AstRecursiveVisitor::Visit(node);
				}

				// Registers the return type and every parameter type of the function
				// NOTE(review): unlike most overrides of this class, this one does not forward to
				// AstRecursiveVisitor::Visit(node) - confirm whether skipping the function body here is intended
				void Visit(ShaderAst::DeclareFunctionStatement& node) override
				{
					m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType));

					for (auto& parameter : node.parameters)
						m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type));
				}

				// Registers the type of every member of the declared struct
				void Visit(ShaderAst::DeclareStructStatement& node) override
				{
					for (auto& field : node.description.members)
						m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
				}

				// Remembers the declared variable type, then keeps walking
				void Visit(ShaderAst::DeclareVariableStatement& node) override
				{
					variableTypes.insert(node.varType);

					AstRecursiveVisitor::Visit(node);
				}

				// Remembers the type the identifier resolves to (through the AST cache), then keeps walking
				void Visit(ShaderAst::IdentifierExpression& node) override
				{
					variableTypes.insert(GetExpressionType(node, m_cache));

					AstRecursiveVisitor::Visit(node);
				}

				// Walks the intrinsic node first, then records the extended instruction set it needs (if any)
				void Visit(ShaderAst::IntrinsicExpression& node) override
				{
					AstRecursiveVisitor::Visit(node);

					switch (node.intrinsic)
					{
						// Require GLSL.std.450
						case ShaderAst::IntrinsicType::CrossProduct:
							extInsts.emplace("GLSL.std.450");
							break;

						// Part of SPIR-V core
						case ShaderAst::IntrinsicType::DotProduct:
							break;
					}
				}

				// Results of the pass, read directly by the caller
				ExtInstList extInsts;
				LocalContainer variableTypes;

			private:
				ShaderAst::AstCache* m_cache;
				const SpirvWriter::States& m_conditions;
				SpirvConstantCache& m_constantCache;
		};

		// Compile-time mapping to a ShaderAst::BasicType; reaching the final branch fails the build with "unhandled type"
		// NOTE(review): the template parameter and the std::is_same_v arguments are missing in this copy,
		// so which C++ type selects which branch cannot be read from here
		template
		constexpr ShaderAst::BasicType GetBasicType()
		{
			if constexpr (std::is_same_v)
				return ShaderAst::BasicType::Boolean;
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Float1);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Int1);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Float2);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Float3);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Float4);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Int2);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Int3);
			else if constexpr (std::is_same_v)
				return(ShaderAst::BasicType::Int4);
			else
				static_assert(AlwaysFalse::value, "unhandled type");
		}
	}

	// Working data of one Generate() call: Generate creates an instance on its stack and
	// points m_currentState at it until it returns
	struct SpirvWriter::State
	{
		// The constant cache is built from nextVarIndex
		// (presumably so both draw result ids from the same counter - confirm against SpirvConstantCache)
		State() :
		constantTypeCache(nextVarIndex)
		{
		}

		// Result id and function type id of a function
		struct Func
		{
			UInt32 typeId;
			UInt32 id;
		};

		// Variables, looked up by name by the Get*/Read* methods below
		tsl::ordered_map inputIds;
		tsl::ordered_map outputIds;
		tsl::ordered_map parameterIds;
		tsl::ordered_map uniformIds;
		std::unordered_map extensionInstructions; // extended instruction set name => result id
		std::unordered_map builtinIds;
		std::unordered_map varToResult; // local variable name => result id
		std::vector funcs;
		std::vector functionBlocks; // OpLoad instructions are appended to the last block
		std::vector resultIds;
		UInt32 nextVarIndex = 1; // next result id handed out by AllocateResultId
		SpirvConstantCache constantTypeCache; //< init after nextVarIndex

		// Output
		SpirvSection header;
		SpirvSection constants;
		SpirvSection debugInfo;
		SpirvSection annotations;
		SpirvSection instructions;
	};

	// A writer has no current state outside of a Generate() call
	SpirvWriter::SpirvWriter() :
	m_currentState(nullptr)
	{
	}

	// Validates the shader AST then runs the pre-pass over it.
	// Throws std::runtime_error("Invalid shader AST: ...") when validation fails.
	// In its current form the function always returns an empty vector: everything after the
	// registration of the extended instruction sets (types, variables, functions, entry point,
	// section merging) is commented out.
	std::vector SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
	{
		std::string error;
		if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
			throw std::runtime_error("Invalid shader AST: " + error);

		// NOTE(review): this keeps the address of the caller's object and is not reset when the function exits
		m_context.states = &conditions;

		State state;
		m_currentState = &state;
		// state lives on this stack frame: make sure the writer stops pointing at it however we leave
		CallOnExit onExit([this]()
		{
			m_currentState = nullptr;
		});

		// NOTE(review): the two locals below are not used by the code currently enabled
		// (functionStatements is only referenced from the commented-out section)
		std::vector functionStatements;

		ShaderAst::AstCloner cloner;

		// Register all extended instruction sets
		PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
		shader->Visit(preVisitor);

		for (const std::string& extInst : preVisitor.extInsts)
			state.extensionInstructions[extInst] = AllocateResultId();

		// Register all types
		/*for (const auto& func : shader.GetFunctions())
		{
			RegisterType(func.returnType);
			for (const auto& param : func.parameters)
				RegisterType(param.type);
		}

		for (const auto& input : shader.GetInputs())
			RegisterPointerType(input.type, SpirvStorageClass::Input);

		for (const auto& output : shader.GetOutputs())
			RegisterPointerType(output.type, SpirvStorageClass::Output);

		for (const auto& uniform : shader.GetUniforms())
			RegisterPointerType(uniform.type, (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform);

		for (const auto& func : shader.GetFunctions())
			RegisterFunctionType(func.returnType, func.parameters);

		for (const auto& type : preVisitor.variableTypes)
			RegisterType(type);

		for (const auto& builtin : preVisitor.builtinVars)
			RegisterType(builtin->type);

		// Register result id and debug infos for global variables/functions
		for (const auto& builtin : preVisitor.builtinVars)
		{
			SpirvConstantCache::Variable variable;
			SpirvBuiltIn builtinDecoration;
			switch (builtin->entry)
			{
				case ShaderAst::BuiltinEntry::VertexPosition:
					variable.debugName = "builtin_VertexPosition";
					variable.storageClass = SpirvStorageClass::Output;

					builtinDecoration = SpirvBuiltIn::Position;
					break;

				default:
					throw std::runtime_error("unexpected builtin type");
			}

			const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
			assert(IsBasicType(builtinExprType));

			ShaderAst::BasicType builtinType = std::get(builtinExprType);

			variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);

			UInt32 varId = m_currentState->constantTypeCache.Register(variable);

			ExtVar builtinData;
			builtinData.pointerTypeId = GetPointerTypeId(builtinType, variable.storageClass);
			builtinData.typeId = GetTypeId(builtinType);
			builtinData.varId = varId;

			state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);

			state.builtinIds.emplace(builtin->entry, builtinData);
		}

		for (const auto& input : shader.GetInputs())
		{
			SpirvConstantCache::Variable variable;
			variable.debugName = input.name;
			variable.storageClass = SpirvStorageClass::Input;
			variable.type = SpirvConstantCache::BuildPointerType(shader, input.type, variable.storageClass);

			UInt32 varId = m_currentState->constantTypeCache.Register(variable);

			ExtVar inputData;
			inputData.pointerTypeId = GetPointerTypeId(input.type, variable.storageClass);
			inputData.typeId = GetTypeId(input.type);
			inputData.varId = varId;

			state.inputIds.emplace(input.name, std::move(inputData));

			if (input.locationIndex)
				state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
		}

		for (const auto& output : shader.GetOutputs())
		{
			SpirvConstantCache::Variable variable;
			variable.debugName = output.name;
			variable.storageClass = SpirvStorageClass::Output;
			variable.type = SpirvConstantCache::BuildPointerType(shader, output.type, variable.storageClass);

			UInt32 varId = m_currentState->constantTypeCache.Register(variable);

			ExtVar outputData;
			outputData.pointerTypeId = GetPointerTypeId(output.type, variable.storageClass);
			outputData.typeId = GetTypeId(output.type);
			outputData.varId = varId;

			state.outputIds.emplace(output.name, std::move(outputData));

			if (output.locationIndex)
				state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
		}

		for (const auto& uniform : shader.GetUniforms())
		{
			SpirvConstantCache::Variable variable;
			variable.debugName = uniform.name;
			variable.storageClass = (IsSamplerType(uniform.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
			variable.type = SpirvConstantCache::BuildPointerType(shader, uniform.type, variable.storageClass);

			UInt32 varId = m_currentState->constantTypeCache.Register(variable);

			ExtVar uniformData;
			uniformData.pointerTypeId = GetPointerTypeId(uniform.type, variable.storageClass);
			uniformData.typeId = GetTypeId(uniform.type);
			uniformData.varId = varId;

			state.uniformIds.emplace(uniform.name, std::move(uniformData));

			if (uniform.bindingIndex)
			{
				state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
				state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
			}
		}

		for (const auto& func : shader.GetFunctions())
		{
			auto& funcData = state.funcs.emplace_back();
			funcData.id = AllocateResultId();
			funcData.typeId = GetFunctionTypeId(func.returnType, func.parameters);

			state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
		}

		std::size_t entryPointIndex = std::numeric_limits::max();

		for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex)
		{
			const auto& func = shader.GetFunction(funcIndex);
			if (func.name == "main")
				entryPointIndex = funcIndex;

			auto& funcData = state.funcs[funcIndex];

			state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);

			state.functionBlocks.clear();
			state.functionBlocks.emplace_back(*this);

			state.parameterIds.clear();

			for (const auto& param : func.parameters)
			{
				UInt32 paramResultId = AllocateResultId();
				state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);

				ExtVar parameterData;
				parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
				parameterData.typeId = GetTypeId(param.type);
				parameterData.varId = paramResultId;

				state.parameterIds.emplace(param.name, std::move(parameterData));
			}

			SpirvAstVisitor visitor(*this, state.functionBlocks);
			visitor.Visit(functionStatements[funcIndex]);

			if (!state.functionBlocks.back().IsTerminated())
			{
				assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
				state.functionBlocks.back().Append(SpirvOp::OpReturn);
			}

			for (SpirvBlock& block : state.functionBlocks)
				state.instructions.AppendSection(block);

			state.instructions.Append(SpirvOp::OpFunctionEnd);
		}

		m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);

		AppendHeader();

		if (entryPointIndex != std::numeric_limits::max())
		{
			SpvExecutionModel execModel;
			const auto& entryFuncData = shader.GetFunction(entryPointIndex);
			const auto& entryFunc = state.funcs[entryPointIndex];

			assert(m_context.shader);
			switch (m_context.shader->GetStage())
			{
				case ShaderStageType::Fragment:
					execModel = SpvExecutionModelFragment;
					break;

				case ShaderStageType::Vertex:
					execModel = SpvExecutionModelVertex;
					break;

				default:
					throw std::runtime_error("not yet implemented");
			}

			// OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos

			state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
			{
				appender(execModel);
				appender(entryFunc.id);
				appender(entryFuncData.name);

				for (const auto& [name, varData] : state.builtinIds)
					appender(varData.varId);

				for (const auto& [name, varData] : state.inputIds)
					appender(varData.varId);

				for (const auto& [name, varData] : state.outputIds)
					appender(varData.varId);
			});

			if (m_context.shader->GetStage() == ShaderStageType::Fragment)
				state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
		}*/

		// Nothing is appended to ret while the merge below stays commented out
		std::vector ret;
		/*MergeSections(ret, state.header);
		MergeSections(ret, state.debugInfo);
		MergeSections(ret, state.annotations);
		MergeSections(ret, state.constants);
		MergeSections(ret, state.instructions);*/

		return ret;
	}

	// Replaces the target environment; its spvMajorVersion/spvMinorVersion are read by AppendHeader
	void SpirvWriter::SetEnv(Environment environment)
	{
		m_environment = std::move(environment);
	}

	// Returns the next free result id and advances the counter
	// m_currentState is dereferenced without a check: only valid while a generation is in progress
	UInt32 SpirvWriter::AllocateResultId()
	{
		return m_currentState->nextVarIndex++;
	}

	// Writes the five raw header words, then the capability, the extended instruction set imports
	// and the memory model, all into the header section of the current state
	void SpirvWriter::AppendHeader()
	{
		m_currentState->header.AppendRaw(SpvMagicNumber); //< Spir-V magic number

		// Major version goes in bits 16-23 and minor version in bits 8-15 (<< binds tighter than |)
		UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8;

		m_currentState->header.AppendRaw(version); //< Spir-V version number (1.0 for compatibility)
		m_currentState->header.AppendRaw(0); //< Generator identifier (TODO: Register generator to Khronos)

		m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count)
		m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now)

		m_currentState->header.Append(SpirvOp::OpCapability, SpvCapabilityShader);

		for (const auto& [extInst, resultId] : m_currentState->extensionInstructions)
			m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);

		m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
	}

	// Builds the constant cache description of a function type:
	// the return type as is, and each parameter as a pointer type in the Function storage class
	SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters)
	{
		std::vector parameterTypes;
		parameterTypes.reserve(parameters.size());

		for (const auto& parameter : parameters)
			parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function));

		return SpirvConstantCache::Function{
			SpirvConstantCache::BuildType(retType),
			std::move(parameterTypes)
		};
	}

	// Asks the constant cache of the current state for the id of a constant (const counterpart of RegisterConstant)
	UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
	{
		return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
	}

	// Asks the constant cache of the current state for the id of a function type
	UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector& parameters)
	{
		return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
	}

	// Returns the data stored for a builtin; the entry must exist (only checked by assert)
	auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
	{
		auto it = m_currentState->builtinIds.find(builtin);
		assert(it != m_currentState->builtinIds.end());

		return it->second;
	}

	// Returns the data stored for an input variable; the name must exist (only checked by assert)
	auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar&
	{
		auto it = m_currentState->inputIds.find(name);
		assert(it != m_currentState->inputIds.end());

		return it->second;
	}

	// Returns the data stored for an output variable; the name must exist (only checked by assert)
	auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar&
	{
		auto it = m_currentState->outputIds.find(name);
		assert(it != m_currentState->outputIds.end());

		return it->second;
	}

	// Returns the data stored for a uniform variable; the name must exist (only checked by assert)
	auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar&
	{
		auto it = m_currentState->uniformIds.find(name);
		assert(it != m_currentState->uniformIds.end());

		return it.value();
	}

	// Asks the constant cache of the current state for the id of a pointer type (type + storage class)
	UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
	{
		return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
	}

	// Asks the constant cache of the current state for the id of a type
	UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
	{
		return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
	}

	// Returns the result id holding the value of an input variable, emitting a load if it has not been read yet
	// The name must exist (only checked by assert)
	UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
	{
		auto it = m_currentState->inputIds.find(name);
		assert(it != m_currentState->inputIds.end());

		return ReadVariable(it.value());
	}

	// Same lookup, but never emits a load: empty optional if the variable has not been read yet
	std::optional SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache)
	{
		auto it = m_currentState->inputIds.find(name);
		assert(it != m_currentState->inputIds.end());

		return ReadVariable(it.value(), OnlyCache{});
	}

	// Returns the result id last written for a local variable; the name must exist (only checked by assert)
	UInt32 SpirvWriter::ReadLocalVariable(const std::string& name)
	{
		auto it = m_currentState->varToResult.find(name);
		assert(it != m_currentState->varToResult.end());

		return it->second;
	}

	// Same lookup, but an unknown name gives an empty optional instead of being an error
	std::optional SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache)
	{
		auto it = m_currentState->varToResult.find(name);
		if (it == m_currentState->varToResult.end())
			return {};

		return it->second;
	}

	// Returns the result id holding the value of a function parameter, emitting a load if it has not been read yet
	// The name must exist (only checked by assert)
	UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
	{
		auto it = m_currentState->parameterIds.find(name);
		assert(it != m_currentState->parameterIds.end());

		return ReadVariable(it.value());
	}

	// Same lookup, but never emits a load: empty optional if the parameter has not been read yet
	std::optional SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
	{
		auto it = m_currentState->parameterIds.find(name);
		assert(it != m_currentState->parameterIds.end());

		return ReadVariable(it.value(), OnlyCache{});
	}

	// Returns the result id holding the value of a uniform variable, emitting a load if it has not been read yet
	// The name must exist (only checked by assert)
	UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
	{
		auto it = m_currentState->uniformIds.find(name);
		assert(it != m_currentState->uniformIds.end());

		return ReadVariable(it.value());
	}

	// Same lookup, but never emits a load: empty optional if the uniform has not been read yet
	std::optional SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache)
	{
		auto it = m_currentState->uniformIds.find(name);
		assert(it != m_currentState->uniformIds.end());

		return ReadVariable(it.value(), OnlyCache{});
	}

	// Returns the result id of the loaded value of var
	// On first use an OpLoad is appended to the last function block and its result id is remembered
	// in var.valueId, so that later reads of the same variable return that id without loading again
	UInt32 SpirvWriter::ReadVariable(ExtVar& var)
	{
		if (!var.valueId.has_value())
		{
			UInt32 resultId = AllocateResultId();
			m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);

			var.valueId = resultId;
		}

		return var.valueId.value();
	}

	// Returns the remembered value id of var if it was already loaded, an empty optional otherwise (nothing is emitted)
	std::optional SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache)
	{
		if (!var.valueId.has_value())
			return {};

		return var.valueId.value();
	}

	// Registers a constant in the constant cache of the current state and returns the id the cache gives back
	UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value)
	{
		return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
	}

	// Registers a function type (see BuildFunctionType) and returns the id the cache gives back
	UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters)
	{
		return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
	}

	// Registers a pointer type (type + storage class) and returns the id the cache gives back
	UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
	{
		return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
	}

	// Registers a type and returns the id the cache gives back; requires a generation in progress
	UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
	{
		assert(m_currentState);
		return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
	}

	// Associates resultId with a local variable name, replacing any previous association
	void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
	{
		assert(m_currentState);
		m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
	}

	// Appends the bytecode of a section at the end of output
	void SpirvWriter::MergeSections(std::vector& output, const SpirvSection& from)
	{
		const std::vector& bytecode = from.GetBytecode();

		// Grow first, then copy into the newly added tail
		std::size_t prevSize = output.size();
		output.resize(prevSize + bytecode.size());
		std::copy(bytecode.begin(), bytecode.end(), output.begin() + prevSize);
	}
}