From 93de44d29304283ca2e0489269a40cf0cf33d7b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sun, 23 Aug 2020 18:32:28 +0200 Subject: [PATCH] Big SpirVWriter refactor --- include/Nazara/Shader/ShaderAstVisitor.hpp | 4 +- .../Nazara/Shader/ShaderAstVisitorExcept.hpp | 36 + .../Nazara/Shader/ShaderVarVisitorExcept.hpp | 28 + include/Nazara/Shader/SpirvExpressionLoad.hpp | 64 ++ include/Nazara/Shader/SpirvExpressionLoad.inl | 16 + .../SpirvExpressionLoadAccessMember.hpp | 62 ++ .../SpirvExpressionLoadAccessMember.inl | 16 + .../Nazara/Shader/SpirvExpressionStore.hpp | 63 ++ .../Nazara/Shader/SpirvExpressionStore.inl | 16 + include/Nazara/Shader/SpirvPrinter.hpp | 42 ++ include/Nazara/Shader/SpirvPrinter.inl | 16 + .../Nazara/Shader/SpirvStatementVisitor.hpp | 43 ++ .../Nazara/Shader/SpirvStatementVisitor.inl | 16 + include/Nazara/Shader/SpirvWriter.hpp | 57 +- src/Nazara/Shader/ShaderAstVisitorExcept.cpp | 75 ++ src/Nazara/Shader/ShaderVarVisitorExcept.cpp | 40 ++ src/Nazara/Shader/SpirvExpressionLoad.cpp | 448 ++++++++++++ .../SpirvExpressionLoadAccessMember.cpp | 116 +++ src/Nazara/Shader/SpirvExpressionStore.cpp | 104 +++ src/Nazara/Shader/SpirvPrinter.cpp | 231 ++++++ src/Nazara/Shader/SpirvStatementVisitor.cpp | 49 ++ src/Nazara/Shader/SpirvWriter.cpp | 680 +++--------------- 22 files changed, 1604 insertions(+), 618 deletions(-) create mode 100644 include/Nazara/Shader/ShaderAstVisitorExcept.hpp create mode 100644 include/Nazara/Shader/ShaderVarVisitorExcept.hpp create mode 100644 include/Nazara/Shader/SpirvExpressionLoad.hpp create mode 100644 include/Nazara/Shader/SpirvExpressionLoad.inl create mode 100644 include/Nazara/Shader/SpirvExpressionLoadAccessMember.hpp create mode 100644 include/Nazara/Shader/SpirvExpressionLoadAccessMember.inl create mode 100644 include/Nazara/Shader/SpirvExpressionStore.hpp create mode 100644 include/Nazara/Shader/SpirvExpressionStore.inl create mode 100644 include/Nazara/Shader/SpirvPrinter.hpp create mode 100644 include/Nazara/Shader/SpirvPrinter.inl create mode 100644 include/Nazara/Shader/SpirvStatementVisitor.hpp create mode 100644 include/Nazara/Shader/SpirvStatementVisitor.inl create mode 100644 src/Nazara/Shader/ShaderAstVisitorExcept.cpp create mode 100644 src/Nazara/Shader/ShaderVarVisitorExcept.cpp create mode 100644 src/Nazara/Shader/SpirvExpressionLoad.cpp create mode 100644 src/Nazara/Shader/SpirvExpressionLoadAccessMember.cpp create mode 100644 src/Nazara/Shader/SpirvExpressionStore.cpp create mode 100644 src/Nazara/Shader/SpirvPrinter.cpp create mode 100644 src/Nazara/Shader/SpirvStatementVisitor.cpp diff --git a/include/Nazara/Shader/ShaderAstVisitor.hpp b/include/Nazara/Shader/ShaderAstVisitor.hpp index a6896cb2c..64dc35559 100644 --- a/include/Nazara/Shader/ShaderAstVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstVisitor.hpp @@ -4,8 +4,8 @@ #pragma once -#ifndef NAZARA_SHADERVISITOR_HPP -#define NAZARA_SHADERVISITOR_HPP +#ifndef NAZARA_SHADERASTVISITOR_HPP +#define NAZARA_SHADERASTVISITOR_HPP #include #include diff --git a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp new file mode 100644 index 000000000..40635477c --- /dev/null +++ b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTVISITOREXCEPT_HPP +#define NAZARA_SHADERASTVISITOREXCEPT_HPP + +#include +#include +#include + +namespace Nz +{ + class NAZARA_SHADER_API ShaderAstVisitorExcept : public ShaderAstVisitor + { + public: + using ShaderAstVisitor::Visit; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; + }; +} + +#endif diff --git a/include/Nazara/Shader/ShaderVarVisitorExcept.hpp b/include/Nazara/Shader/ShaderVarVisitorExcept.hpp new file mode 100644 index 000000000..3fa769e21 --- /dev/null +++ b/include/Nazara/Shader/ShaderVarVisitorExcept.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2015 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERVARVISITOREXCEPT_HPP +#define NAZARA_SHADERVARVISITOREXCEPT_HPP + +#include +#include + +namespace Nz +{ + class NAZARA_SHADER_API ShaderVarVisitorExcept : public ShaderVarVisitor + { + public: + using ShaderVarVisitor::Visit; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::OutputVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::UniformVariable& var) override; + }; +} + +#endif diff --git a/include/Nazara/Shader/SpirvExpressionLoad.hpp b/include/Nazara/Shader/SpirvExpressionLoad.hpp new file mode 100644 index 000000000..a766e568d --- /dev/null +++ b/include/Nazara/Shader/SpirvExpressionLoad.hpp @@ -0,0 +1,64 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVEXPRESSIONLOAD_HPP +#define NAZARA_SPIRVEXPRESSIONLOAD_HPP + +#include +#include +#include +#include +#include + +namespace Nz +{ + class SpirvWriter; + + class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept + { + public: + inline SpirvExpressionLoad(SpirvWriter& writer); + SpirvExpressionLoad(const SpirvExpressionLoad&) = delete; + SpirvExpressionLoad(SpirvExpressionLoad&&) = delete; + ~SpirvExpressionLoad() = default; + + UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr); + + using ShaderAstVisitorExcept::Visit; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; + + using ShaderVarVisitorExcept::Visit; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::UniformVariable& var) override; + + SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete; + SpirvExpressionLoad& operator=(SpirvExpressionLoad&&) = delete; + + private: + void PushResultId(UInt32 value); + UInt32 PopResultId(); + + std::vector m_resultIds; + SpirvWriter& m_writer; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvExpressionLoad.inl b/include/Nazara/Shader/SpirvExpressionLoad.inl new file mode 100644 index 000000000..966aae912 --- /dev/null +++ b/include/Nazara/Shader/SpirvExpressionLoad.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer) : + m_writer(writer) + { + } +} + +#include diff --git a/include/Nazara/Shader/SpirvExpressionLoadAccessMember.hpp b/include/Nazara/Shader/SpirvExpressionLoadAccessMember.hpp new file mode 100644 index 000000000..8e2e0ff3b --- /dev/null +++ b/include/Nazara/Shader/SpirvExpressionLoadAccessMember.hpp @@ -0,0 +1,62 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP +#define NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP + +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + class SpirvWriter; + + class NAZARA_SHADER_API SpirvExpressionLoadAccessMember : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept + { + public: + inline SpirvExpressionLoadAccessMember(SpirvWriter& writer); + SpirvExpressionLoadAccessMember(const SpirvExpressionLoadAccessMember&) = delete; + SpirvExpressionLoadAccessMember(SpirvExpressionLoadAccessMember&&) = delete; + ~SpirvExpressionLoadAccessMember() = default; + + UInt32 EvaluateExpression(ShaderNodes::AccessMember& expr); + + using ShaderAstVisitor::Visit; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::Identifier& node) override; + + using ShaderVarVisitor::Visit; + void Visit(ShaderNodes::InputVariable& var) override; + void Visit(ShaderNodes::UniformVariable& var) override; + + SpirvExpressionLoadAccessMember& operator=(const SpirvExpressionLoadAccessMember&) = delete; + SpirvExpressionLoadAccessMember& operator=(SpirvExpressionLoadAccessMember&&) = delete; + + private: + struct Pointer + { + SpirvStorageClass storage; + UInt32 resultId; + UInt32 pointedTypeId; + }; + + struct Value + { + UInt32 resultId; + }; + + SpirvWriter& m_writer; + std::variant m_value; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvExpressionLoadAccessMember.inl b/include/Nazara/Shader/SpirvExpressionLoadAccessMember.inl new file mode 100644 index 000000000..d81cfbb9c --- /dev/null +++ b/include/Nazara/Shader/SpirvExpressionLoadAccessMember.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline SpirvExpressionLoadAccessMember::SpirvExpressionLoadAccessMember(SpirvWriter& writer) : + m_writer(writer) + { + } +} + +#include diff --git a/include/Nazara/Shader/SpirvExpressionStore.hpp b/include/Nazara/Shader/SpirvExpressionStore.hpp new file mode 100644 index 000000000..d7d37e39d --- /dev/null +++ b/include/Nazara/Shader/SpirvExpressionStore.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVEXPRESSIONSTORE_HPP +#define NAZARA_SPIRVEXPRESSIONSTORE_HPP + +#include +#include +#include +#include +#include + +namespace Nz +{ + class SpirvSection; + class SpirvWriter; + + class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept + { + public: + inline SpirvExpressionStore(SpirvWriter& writer); + SpirvExpressionStore(const SpirvExpressionStore&) = delete; + SpirvExpressionStore(SpirvExpressionStore&&) = delete; + ~SpirvExpressionStore() = default; + + void Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId); + + using ShaderAstVisitorExcept::Visit; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; + + using ShaderVarVisitorExcept::Visit; + void Visit(ShaderNodes::BuiltinVariable& var) override; + void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::OutputVariable& var) override; + + SpirvExpressionStore& operator=(const SpirvExpressionStore&) = delete; + SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete; + + private: + struct LocalVar + { + std::string varName; + }; + + struct Pointer + { + SpirvStorageClass storage; + UInt32 resultId; + }; + + SpirvWriter& m_writer; + std::variant m_value; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvExpressionStore.inl b/include/Nazara/Shader/SpirvExpressionStore.inl new file mode 100644 index 000000000..558a2aee8 --- /dev/null +++ b/include/Nazara/Shader/SpirvExpressionStore.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer) : + m_writer(writer) + { + } +} + +#include diff --git a/include/Nazara/Shader/SpirvPrinter.hpp b/include/Nazara/Shader/SpirvPrinter.hpp new file mode 100644 index 000000000..b3e359fb5 --- /dev/null +++ b/include/Nazara/Shader/SpirvPrinter.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVPRINTER_HPP +#define NAZARA_SPIRVPRINTER_HPP + +#include +#include +#include + +namespace Nz +{ + class NAZARA_SHADER_API SpirvPrinter + { + public: + inline SpirvPrinter(); + SpirvPrinter(const SpirvPrinter&) = default; + SpirvPrinter(SpirvPrinter&&) = default; + ~SpirvPrinter() = default; + + std::string Print(const UInt32* codepoints, std::size_t count); + + SpirvPrinter& operator=(const SpirvPrinter&) = default; + SpirvPrinter& operator=(SpirvPrinter&&) = default; + + private: + void AppendInstruction(); + std::string ReadString(); + UInt32 ReadWord(); + + struct State; + + State* m_currentState; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvPrinter.inl b/include/Nazara/Shader/SpirvPrinter.inl new file mode 100644 index 000000000..f81ee5c21 --- /dev/null +++ b/include/Nazara/Shader/SpirvPrinter.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline SpirvPrinter::SpirvPrinter() : + m_currentState(nullptr) + { + } +} + +#include diff --git a/include/Nazara/Shader/SpirvStatementVisitor.hpp b/include/Nazara/Shader/SpirvStatementVisitor.hpp new file mode 100644 index 000000000..1ba88942c --- /dev/null +++ b/include/Nazara/Shader/SpirvStatementVisitor.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVSTATEMENTVISITOR_HPP +#define NAZARA_SPIRVSTATEMENTVISITOR_HPP + +#include +#include +#include + +namespace Nz +{ + class SpirvWriter; + + class NAZARA_SHADER_API SpirvStatementVisitor : public ShaderAstVisitorExcept + { + public: + inline SpirvStatementVisitor(SpirvWriter& writer); + SpirvStatementVisitor(const SpirvStatementVisitor&) = delete; + SpirvStatementVisitor(SpirvStatementVisitor&&) = delete; + ~SpirvStatementVisitor() = default; + + using ShaderAstVisitor::Visit; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::Branch& node) override; + void Visit(ShaderNodes::DeclareVariable& node) override; + void Visit(ShaderNodes::ExpressionStatement& node) override; + void Visit(ShaderNodes::StatementBlock& node) override; + + SpirvStatementVisitor& operator=(const SpirvStatementVisitor&) = delete; + SpirvStatementVisitor& operator=(SpirvStatementVisitor&&) = delete; + + private: + SpirvWriter& m_writer; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvStatementVisitor.inl b/include/Nazara/Shader/SpirvStatementVisitor.inl new file mode 100644 index 000000000..fc2274c10 --- /dev/null +++ b/include/Nazara/Shader/SpirvStatementVisitor.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline SpirvStatementVisitor::SpirvStatementVisitor(SpirvWriter& writer) : + m_writer(writer) + { + } +} + +#include diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 579ad1fef..8bb38087b 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -23,8 +23,13 @@ namespace Nz { class SpirvSection; - class NAZARA_SHADER_API SpirvWriter : public ShaderAstVisitor, public ShaderVarVisitor + class NAZARA_SHADER_API SpirvWriter { + friend class SpirvExpressionLoad; + friend class SpirvExpressionLoadAccessMember; + friend class SpirvExpressionStore; + friend class SpirvStatementVisitor; + public: struct Environment; @@ -45,49 +50,37 @@ namespace Nz private: struct ExtVar; + struct OnlyCache {}; UInt32 AllocateResultId(); void AppendHeader(); - UInt32 EvaluateExpression(const ShaderNodes::ExpressionPtr& expr); - UInt32 GetConstantId(const ShaderConstantValue& value) const; UInt32 GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters); + const ExtVar& GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const; + const ExtVar& GetInputVariable(const std::string& name) const; + const ExtVar& GetOutputVariable(const std::string& name) const; + const ExtVar& GetUniformVariable(const std::string& name) const; + SpirvSection& GetInstructions(); UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const; UInt32 GetTypeId(const ShaderExpressionType& type) const; - void PushResultId(UInt32 value); - UInt32 PopResultId(); - + UInt32 ReadInputVariable(const std::string& name); + std::optional ReadInputVariable(const std::string& name, OnlyCache); + UInt32 ReadLocalVariable(const std::string& name); + std::optional ReadLocalVariable(const std::string& name, OnlyCache); + UInt32 ReadUniformVariable(const std::string& name); + std::optional ReadUniformVariable(const std::string& name, OnlyCache); UInt32 ReadVariable(ExtVar& var); + std::optional ReadVariable(const ExtVar& var, OnlyCache); + UInt32 RegisterConstant(const ShaderConstantValue& value); UInt32 RegisterFunctionType(ShaderExpressionType retType, const std::vector& parameters); UInt32 RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass); UInt32 RegisterType(ShaderExpressionType type); - using ShaderAstVisitor::Visit; - void Visit(ShaderNodes::AccessMember& node) override; - void Visit(ShaderNodes::AssignOp& node) override; - void Visit(ShaderNodes::Branch& node) override; - void Visit(ShaderNodes::BinaryOp& node) override; - void Visit(ShaderNodes::Cast& node) override; - void Visit(ShaderNodes::Constant& node) override; - void Visit(ShaderNodes::DeclareVariable& node) override; - void Visit(ShaderNodes::ExpressionStatement& node) override; - void Visit(ShaderNodes::Identifier& node) override; - void Visit(ShaderNodes::IntrinsicCall& node) override; - void Visit(ShaderNodes::Sample2D& node) override; - void Visit(ShaderNodes::StatementBlock& node) override; - void Visit(ShaderNodes::SwizzleOp& node) override; - - using ShaderVarVisitor::Visit; - void Visit(ShaderNodes::BuiltinVariable& var) override; - void Visit(ShaderNodes::InputVariable& var) override; - void Visit(ShaderNodes::LocalVariable& var) override; - void Visit(ShaderNodes::OutputVariable& var) override; - void Visit(ShaderNodes::ParameterVariable& var) override; - void Visit(ShaderNodes::UniformVariable& var) override; + void WriteLocalVariable(std::string name, UInt32 resultId); static void MergeBlocks(std::vector& output, const SpirvSection& from); @@ -97,6 +90,14 @@ namespace Nz const ShaderAst::Function* currentFunction = nullptr; }; + struct ExtVar + { + UInt32 pointerTypeId; + UInt32 typeId; + UInt32 varId; + std::optional valueId; + }; + struct State; Context m_context; diff --git a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp new file mode 100644 index 000000000..e61fcdb8c --- /dev/null +++ b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace Nz +{ + void ShaderAstVisitorExcept::Visit(ShaderNodes::AccessMember& node) + { + throw std::runtime_error("unhandled AccessMember node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::AssignOp& node) + { + throw std::runtime_error("unhandled AssignOp node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::BinaryOp& node) + { + throw std::runtime_error("unhandled AccessMember node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::Branch& node) + { + throw std::runtime_error("unhandled Branch node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::Cast& node) + { + throw std::runtime_error("unhandled Cast node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& node) + { + throw std::runtime_error("unhandled Constant node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::DeclareVariable& node) + { + throw std::runtime_error("unhandled DeclareVariable node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::ExpressionStatement& node) + { + throw std::runtime_error("unhandled ExpressionStatement node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::Identifier& node) + { + throw std::runtime_error("unhandled Identifier node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::IntrinsicCall& node) + { + throw std::runtime_error("unhandled IntrinsicCall node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& node) + { + throw std::runtime_error("unhandled Sample2D node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::StatementBlock& node) + { + throw std::runtime_error("unhandled StatementBlock node"); + } + + void ShaderAstVisitorExcept::Visit(ShaderNodes::SwizzleOp& node) + { + throw std::runtime_error("unhandled SwizzleOp node"); + } +} diff --git a/src/Nazara/Shader/ShaderVarVisitorExcept.cpp b/src/Nazara/Shader/ShaderVarVisitorExcept.cpp new file mode 100644 index 000000000..57b5bdddc --- /dev/null +++ b/src/Nazara/Shader/ShaderVarVisitorExcept.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace Nz +{ + void ShaderVarVisitorExcept::Visit(ShaderNodes::BuiltinVariable& var) + { + throw std::runtime_error("unhandled BuiltinVariable"); + } + + void ShaderVarVisitorExcept::Visit(ShaderNodes::InputVariable& var) + { + throw std::runtime_error("unhandled InputVariable"); + } + + void ShaderVarVisitorExcept::Visit(ShaderNodes::LocalVariable& var) + { + throw std::runtime_error("unhandled LocalVariable"); + } + + void ShaderVarVisitorExcept::Visit(ShaderNodes::OutputVariable& var) + { + throw std::runtime_error("unhandled OutputVariable"); + } + + void ShaderVarVisitorExcept::Visit(ShaderNodes::ParameterVariable& var) + { + throw std::runtime_error("unhandled ParameterVariable"); + } + + void ShaderVarVisitorExcept::Visit(ShaderNodes::UniformVariable& var) + { + throw std::runtime_error("unhandled UniformVariable"); + } +} diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp new file mode 100644 index 000000000..b032326fb --- /dev/null +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -0,0 +1,448 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + void SpirvExpressionLoad::Visit(ShaderNodes::AccessMember& node) + { + SpirvExpressionLoadAccessMember accessMemberVisitor(m_writer); + PushResultId(accessMemberVisitor.EvaluateExpression(node)); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::AssignOp& node) + { + SpirvExpressionLoad loadVisitor(m_writer); + SpirvExpressionStore storeVisitor(m_writer); + storeVisitor.Store(node.left, EvaluateExpression(node.right)); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::BinaryOp& node) + { + ShaderExpressionType resultExprType = node.GetExpressionType(); + assert(std::holds_alternative(resultExprType)); + + const ShaderExpressionType& leftExprType = node.left->GetExpressionType(); + assert(std::holds_alternative(leftExprType)); + + const ShaderExpressionType& rightExprType = node.right->GetExpressionType(); + assert(std::holds_alternative(rightExprType)); + + ShaderNodes::BasicType resultType = std::get(resultExprType); + ShaderNodes::BasicType leftType = std::get(leftExprType); + ShaderNodes::BasicType rightType = std::get(rightExprType); + + + UInt32 leftOperand = EvaluateExpression(node.left); + UInt32 rightOperand = EvaluateExpression(node.right); + UInt32 resultId = m_writer.AllocateResultId(); + + bool swapOperands = false; + + SpirvOp op = [&] + { + switch (node.op) + { + case ShaderNodes::BinaryType::Add: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpirvOp::OpFAdd; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + return SpirvOp::OpIAdd; + + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Substract: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpirvOp::OpFSub; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + return SpirvOp::OpISub; + + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Divide: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpirvOp::OpFDiv; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + return SpirvOp::OpSDiv; + + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + return SpirvOp::OpUDiv; + + case ShaderNodes::BasicType::Boolean: + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Equality: + { + switch (leftType) + { + case ShaderNodes::BasicType::Boolean: + return SpirvOp::OpLogicalEqual; + + case ShaderNodes::BasicType::Float1: + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + case ShaderNodes::BasicType::Mat4x4: + return SpirvOp::OpFOrdEqual; + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + return SpirvOp::OpIEqual; + + case ShaderNodes::BasicType::Sampler2D: + case ShaderNodes::BasicType::Void: + break; + } + } + + case ShaderNodes::BinaryType::Multiply: + { + switch (leftType) + { + case ShaderNodes::BasicType::Float1: + { + switch (rightType) + { + case ShaderNodes::BasicType::Float1: + return SpirvOp::OpFMul; + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + swapOperands = true; + return SpirvOp::OpVectorTimesScalar; + + case ShaderNodes::BasicType::Mat4x4: + swapOperands = true; + return SpirvOp::OpMatrixTimesScalar; + + default: + break; + } + + break; + } + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + { + switch (rightType) + { + case ShaderNodes::BasicType::Float1: + return SpirvOp::OpVectorTimesScalar; + + case ShaderNodes::BasicType::Float2: + case ShaderNodes::BasicType::Float3: + case ShaderNodes::BasicType::Float4: + return SpirvOp::OpFMul; + + case ShaderNodes::BasicType::Mat4x4: + return SpirvOp::OpVectorTimesMatrix; + + default: + break; + } + + break; + } + + case ShaderNodes::BasicType::Int1: + case ShaderNodes::BasicType::Int2: + case ShaderNodes::BasicType::Int3: + case ShaderNodes::BasicType::Int4: + case ShaderNodes::BasicType::UInt1: + case ShaderNodes::BasicType::UInt2: + case ShaderNodes::BasicType::UInt3: + case ShaderNodes::BasicType::UInt4: + return SpirvOp::OpIMul; + + case ShaderNodes::BasicType::Mat4x4: + { + switch (rightType) + { + case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar; + case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector; + case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; + + default: + break; + } + + break; + } + + default: + break; + } + break; + } + } + + assert(false); + throw std::runtime_error("unexpected binary operation"); + }(); + + if (swapOperands) + std::swap(leftOperand, rightOperand); + + m_writer.GetInstructions().Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand); + PushResultId(resultId); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::Cast& node) + { + const ShaderExpressionType& targetExprType = node.exprType; + assert(std::holds_alternative(targetExprType)); + + ShaderNodes::BasicType targetType = std::get(targetExprType); + + StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); + + for (const auto& exprPtr : node.expressions) + { + if (!exprPtr) + break; + + exprResults.push_back(EvaluateExpression(exprPtr)); + } + + UInt32 resultId = m_writer.AllocateResultId(); + + m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) + { + appender(m_writer.GetTypeId(targetType)); + appender(resultId); + + for (UInt32 exprResultId : exprResults) + appender(exprResultId); + }); + + PushResultId(resultId); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::Constant& node) + { + std::visit([&] (const auto& value) + { + PushResultId(m_writer.GetConstantId(value)); + }, node.value); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::DeclareVariable& node) + { + if (node.expression) + { + assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); + + const auto& localVar = static_cast(*node.variable); + m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression)); + } + } + + void SpirvExpressionLoad::Visit(ShaderNodes::ExpressionStatement& node) + { + Visit(node.expression); + PopResultId(); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::Identifier& node) + { + Visit(node.var); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::IntrinsicCall& node) + { + switch (node.intrinsic) + { + case ShaderNodes::IntrinsicType::DotProduct: + { + const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType(); + assert(std::holds_alternative(vecExprType)); + + ShaderNodes::BasicType vecType = std::get(vecExprType); + + UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType)); + + UInt32 vec1 = EvaluateExpression(node.parameters[0]); + UInt32 vec2 = EvaluateExpression(node.parameters[1]); + + UInt32 resultId = m_writer.AllocateResultId(); + + m_writer.GetInstructions().Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2); + PushResultId(resultId); + break; + } + + case ShaderNodes::IntrinsicType::CrossProduct: + default: + throw std::runtime_error("not yet implemented"); + } + } + + void SpirvExpressionLoad::Visit(ShaderNodes::Sample2D& node) + { + UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4); + + UInt32 samplerId = EvaluateExpression(node.sampler); + UInt32 coordinatesId = EvaluateExpression(node.coordinates); + UInt32 resultId = m_writer.AllocateResultId(); + + m_writer.GetInstructions().Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId); + PushResultId(resultId); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::SwizzleOp& node) + { + const ShaderExpressionType& targetExprType = node.GetExpressionType(); + assert(std::holds_alternative(targetExprType)); + + ShaderNodes::BasicType targetType = std::get(targetExprType); + + UInt32 exprResultId = EvaluateExpression(node.expression); + UInt32 resultId = m_writer.AllocateResultId(); + + if (node.componentCount > 1) + { + // Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands + m_writer.GetInstructions().AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) + { + appender(m_writer.GetTypeId(targetType)); + appender(resultId); + appender(exprResultId); + appender(exprResultId); + + for (std::size_t i = 0; i < node.componentCount; ++i) + appender(UInt32(node.components[0]) - UInt32(node.components[i])); + }); + } + else + { + // Extract a single component from the vector + assert(node.componentCount == 1); + + m_writer.GetInstructions().Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) ); + } + + PushResultId(resultId); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::BuiltinVariable& /*var*/) + { + throw std::runtime_error("not implemented yet"); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::InputVariable& var) + { + PushResultId(m_writer.ReadInputVariable(var.name)); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::LocalVariable& var) + { + PushResultId(m_writer.ReadLocalVariable(var.name)); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& /*var*/) + { + throw std::runtime_error("not implemented yet"); + } + + void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var) + { + PushResultId(m_writer.ReadUniformVariable(var.name)); + } + + UInt32 SpirvExpressionLoad::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr) + { + Visit(expr); + return PopResultId(); + } + + void SpirvExpressionLoad::PushResultId(UInt32 value) + { + m_resultIds.push_back(value); + } + + UInt32 SpirvExpressionLoad::PopResultId() + { + if (m_resultIds.empty()) + throw std::runtime_error("invalid operation"); + + UInt32 resultId = m_resultIds.back(); + m_resultIds.pop_back(); + + return resultId; + } +} diff --git a/src/Nazara/Shader/SpirvExpressionLoadAccessMember.cpp b/src/Nazara/Shader/SpirvExpressionLoadAccessMember.cpp new file mode 100644 index 000000000..c59bd806d --- /dev/null +++ b/src/Nazara/Shader/SpirvExpressionLoadAccessMember.cpp @@ -0,0 +1,116 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include + +namespace Nz +{ + namespace + { + template struct overloaded : Ts... { using Ts::operator()...; }; + template overloaded(Ts...)->overloaded; + } + + UInt32 SpirvExpressionLoadAccessMember::EvaluateExpression(ShaderNodes::AccessMember& expr) + { + Visit(expr); + + return std::visit(overloaded + { + [&](const Pointer& pointer) -> UInt32 + { + UInt32 resultId = m_writer.AllocateResultId(); + + m_writer.GetInstructions().Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId); + + return resultId; + }, + [&](const Value& value) -> UInt32 + { + return value.resultId; + }, + [this](std::monostate) -> UInt32 + { + throw std::runtime_error("an internal error occurred"); + } + }, m_value); + } + + void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::AccessMember& node) + { + Visit(node.structExpr); + + std::visit(overloaded + { + [&](const Pointer& pointer) + { + UInt32 resultId = m_writer.AllocateResultId(); + UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME + UInt32 typeId = m_writer.GetTypeId(node.exprType); + + m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) + { + appender(pointerType); + appender(resultId); + appender(pointer.resultId); + + for (std::size_t index : node.memberIndices) + appender(m_writer.GetConstantId(Int32(index))); + }); + + m_value = Pointer { pointer.storage, resultId, typeId }; + }, + [&](const Value& value) + { + UInt32 resultId = m_writer.AllocateResultId(); + UInt32 typeId = m_writer.GetTypeId(node.exprType); + + m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender) + { + appender(typeId); + appender(resultId); + appender(value.resultId); + + for (std::size_t index : node.memberIndices) + appender(m_writer.GetConstantId(Int32(index))); + }); + + m_value = Value { resultId }; + }, + [this](std::monostate) + { + throw std::runtime_error("an internal error occurred"); + } + }, m_value); + } + + void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::Identifier& node) + { + Visit(node.var); + } + + void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::InputVariable& var) + { + auto inputVar = m_writer.GetInputVariable(var.name); + + if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{})) + m_value = Value{ *resultIdOpt }; + else + m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId }; + } + + void SpirvExpressionLoadAccessMember::Visit(ShaderNodes::UniformVariable& var) + { + auto uniformVar = m_writer.GetUniformVariable(var.name); + + if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{})) + m_value = Value{ *resultIdOpt }; + else + m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId }; + } +} diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp new file mode 100644 index 000000000..109ad53c3 --- /dev/null +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace Nz +{ + namespace + { + template struct overloaded : Ts... { using Ts::operator()...; }; + template overloaded(Ts...)->overloaded; + } + + void SpirvExpressionStore::Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId) + { + Visit(node); + + std::visit(overloaded + { + [&](const Pointer& pointer) + { + m_writer.GetInstructions().Append(SpirvOp::OpStore, pointer.resultId, resultId); + }, + [&](const LocalVar& value) + { + m_writer.WriteLocalVariable(value.varName, resultId); + }, + [this](std::monostate) + { + throw std::runtime_error("an internal error occurred"); + } + }, m_value); + } + + void SpirvExpressionStore::Visit(ShaderNodes::AccessMember& node) + { + Visit(node.structExpr); + + std::visit(overloaded + { + [&](const Pointer& pointer) -> UInt32 + { + UInt32 resultId = m_writer.AllocateResultId(); + UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME + UInt32 typeId = m_writer.GetTypeId(node.exprType); + + m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) + { + appender(pointerType); + appender(resultId); + appender(pointer.resultId); + + for (std::size_t index : node.memberIndices) + appender(m_writer.GetConstantId(Int32(index))); + }); + + m_value = Pointer{ pointer.storage, resultId }; + + return resultId; + }, + [&](const LocalVar& value) -> UInt32 + { + throw std::runtime_error("not yet implemented"); + }, + [this](std::monostate) -> UInt32 + { + throw std::runtime_error("an internal error occurred"); + } + }, m_value); + } + + void SpirvExpressionStore::Visit(ShaderNodes::Identifier& node) + { + Visit(node.var); + } + + void SpirvExpressionStore::Visit(ShaderNodes::SwizzleOp& node) + { + throw std::runtime_error("not yet implemented"); + } + + void SpirvExpressionStore::Visit(ShaderNodes::BuiltinVariable& var) + { + const auto& outputVar = m_writer.GetBuiltinVariable(var.entry); + + m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId }; + } + + void SpirvExpressionStore::Visit(ShaderNodes::LocalVariable& var) + { + m_value = LocalVar{ var.name }; + } + + void SpirvExpressionStore::Visit(ShaderNodes::OutputVariable& var) + { + const auto& outputVar = m_writer.GetOutputVariable(var.name); + + m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId }; + } +} diff --git a/src/Nazara/Shader/SpirvPrinter.cpp b/src/Nazara/Shader/SpirvPrinter.cpp new file mode 100644 index 000000000..4b21ed840 --- /dev/null +++ b/src/Nazara/Shader/SpirvPrinter.cpp @@ -0,0 +1,231 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Nz +{ + struct SpirvPrinter::State + { + const UInt32* codepoints; + std::size_t index = 0; + std::size_t count; + std::ostringstream stream; + }; + + std::string SpirvPrinter::Print(const UInt32* codepoints, std::size_t count) + { + State state; + state.codepoints = codepoints; + state.count = count; + + m_currentState = &state; + CallOnExit resetOnExit([&] { m_currentState = nullptr; }); + + UInt32 magicNumber = ReadWord(); + if (magicNumber != SpvMagicNumber) + throw std::runtime_error("invalid Spir-V: magic number didn't match"); + + m_currentState->stream << "Spir-V module\n"; + + UInt32 versionNumber = ReadWord(); + if (versionNumber > SpvVersion) + throw std::runtime_error("Spir-V is more recent than printer, dismissing"); + + UInt8 majorVersion = ((versionNumber) >> 16) & 0xFF; + UInt8 minorVersion = ((versionNumber) >> 8) & 0xFF; + + m_currentState->stream << "Version " + std::to_string(+majorVersion) << "." << std::to_string(+minorVersion) << "\n"; + + UInt32 generatorId = ReadWord(); + + m_currentState->stream << "Generator: " << std::to_string(generatorId) << "\n"; + + UInt32 bound = ReadWord(); + m_currentState->stream << "Bound: " << std::to_string(bound) << "\n"; + + UInt32 schema = ReadWord(); + m_currentState->stream << "Schema: " << std::to_string(schema) << "\n"; + + while (m_currentState->index < m_currentState->count) + AppendInstruction(); + + return m_currentState->stream.str(); + } + + void SpirvPrinter::AppendInstruction() + { + std::size_t startIndex = m_currentState->index; + + UInt32 firstWord = ReadWord(); + + UInt16 wordCount = static_cast((firstWord >> 16) & 0xFFFF); + UInt16 opcode = static_cast(firstWord & 0xFFFF); + + const SpirvInstruction* inst = GetInstructionData(opcode); + if (!inst) + throw std::runtime_error("invalid instruction"); + + m_currentState->stream << inst->name; + + std::size_t currentOperand = 0; + std::size_t instructionEnd = startIndex + wordCount; + while (m_currentState->index < instructionEnd) + { + const SpirvInstruction::Operand* operand = &inst->operands[currentOperand]; + + m_currentState->stream << " " << operand->name << "("; + + switch (operand->kind) + { + case SpirvOperandKind::ImageOperands: + case SpirvOperandKind::FPFastMathMode: + case SpirvOperandKind::SelectionControl: + case SpirvOperandKind::LoopControl: + case SpirvOperandKind::FunctionControl: + case SpirvOperandKind::MemorySemantics: + case SpirvOperandKind::MemoryAccess: + case SpirvOperandKind::KernelProfilingInfo: + case SpirvOperandKind::RayFlags: + case SpirvOperandKind::SourceLanguage: + case SpirvOperandKind::ExecutionModel: + case SpirvOperandKind::AddressingModel: + case SpirvOperandKind::MemoryModel: + case SpirvOperandKind::ExecutionMode: + case SpirvOperandKind::StorageClass: + case SpirvOperandKind::Dim: + case SpirvOperandKind::SamplerAddressingMode: + case SpirvOperandKind::SamplerFilterMode: + case SpirvOperandKind::ImageFormat: + case SpirvOperandKind::ImageChannelOrder: + case SpirvOperandKind::ImageChannelDataType: + case SpirvOperandKind::FPRoundingMode: + case SpirvOperandKind::LinkageType: + case SpirvOperandKind::AccessQualifier: + case SpirvOperandKind::FunctionParameterAttribute: + case SpirvOperandKind::Decoration: + case SpirvOperandKind::BuiltIn: + case SpirvOperandKind::Scope: + case SpirvOperandKind::GroupOperation: + case SpirvOperandKind::KernelEnqueueFlags: + case SpirvOperandKind::Capability: + case SpirvOperandKind::RayQueryIntersection: + case SpirvOperandKind::RayQueryCommittedIntersectionType: + case SpirvOperandKind::RayQueryCandidateIntersectionType: + case SpirvOperandKind::IdResultType: + case SpirvOperandKind::IdResult: + case SpirvOperandKind::IdMemorySemantics: + case SpirvOperandKind::IdScope: + case SpirvOperandKind::IdRef: + case SpirvOperandKind::LiteralInteger: + case SpirvOperandKind::LiteralExtInstInteger: + case SpirvOperandKind::LiteralSpecConstantOpInteger: + case SpirvOperandKind::LiteralContextDependentNumber: //< FIXME + { + UInt32 value = ReadWord(); + m_currentState->stream << value; + break; + } + + case SpirvOperandKind::LiteralString: + { + std::string str = ReadString(); + m_currentState->stream << "\"" << str << "\""; + + /* + std::size_t offset = GetOutputOffset(); + + std::size_t size4 = CountWord(str); + for (std::size_t i = 0; i < size4; ++i) + { + UInt32 codepoint = 0; + for (std::size_t j = 0; j < 4; ++j) + { + std::size_t pos = i * 4 + j; + if (pos < str.size()) + codepoint |= UInt32(str[pos]) << (j * 8); + } + + Append(codepoint); + } + */ + break; + } + + + case SpirvOperandKind::PairLiteralIntegerIdRef: + { + ReadWord(); + ReadWord(); + break; + } + + case SpirvOperandKind::PairIdRefLiteralInteger: + { + ReadWord(); + ReadWord(); + break; + } + + case SpirvOperandKind::PairIdRefIdRef: + { + ReadWord(); + ReadWord(); + break; + } + + /*case SpirvOperandKind::LiteralContextDependentNumber: + { + throw std::runtime_error("not yet implemented"); + }*/ + + default: + break; + + } + + m_currentState->stream << ")"; + + if (currentOperand < inst->minOperandCount - 1) + currentOperand++; + } + + m_currentState->stream << "\n"; + + assert(m_currentState->index == startIndex + wordCount); + } + + std::string SpirvPrinter::ReadString() + { + std::string str; + + for (;;) + { + UInt32 value = ReadWord(); + for (std::size_t j = 0; j < 4; ++j) + { + char c = static_cast((value >> (j * 8)) & 0xFF); + if (c == '\0') + return str; + + str.push_back(c); + } + } + } + + UInt32 SpirvPrinter::ReadWord() + { + if (m_currentState->index >= m_currentState->count) + throw std::runtime_error("unexpected end of stream"); + + return m_currentState->codepoints[m_currentState->index++]; + } +} diff --git a/src/Nazara/Shader/SpirvStatementVisitor.cpp b/src/Nazara/Shader/SpirvStatementVisitor.cpp new file mode 100644 index 000000000..312879375 --- /dev/null +++ b/src/Nazara/Shader/SpirvStatementVisitor.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include + +namespace Nz +{ + void SpirvStatementVisitor::Visit(ShaderNodes::AssignOp& node) + { + SpirvExpressionLoad loadVisitor(m_writer); + SpirvExpressionStore storeVisitor(m_writer); + storeVisitor.Store(node.left, loadVisitor.EvaluateExpression(node.right)); + } + + void SpirvStatementVisitor::Visit(ShaderNodes::Branch& node) + { + throw std::runtime_error("not yet implemented"); + } + + void SpirvStatementVisitor::Visit(ShaderNodes::DeclareVariable& node) + { + if (node.expression) + { + assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); + + const auto& localVar = static_cast(*node.variable); + + SpirvExpressionLoad loadVisitor(m_writer); + m_writer.WriteLocalVariable(localVar.name, loadVisitor.EvaluateExpression(node.expression)); + } + } + + void SpirvStatementVisitor::Visit(ShaderNodes::ExpressionStatement& node) + { + SpirvExpressionLoad loadVisitor(m_writer); + loadVisitor.Visit(node.expression); + } + + void SpirvStatementVisitor::Visit(ShaderNodes::StatementBlock& node) + { + for (auto& statement : node.statements) + Visit(statement); + } +} diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 7a1f41bff..8412a1491 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -154,14 +155,6 @@ namespace Nz } } - struct SpirvWriter::ExtVar - { - UInt32 pointerTypeId; - UInt32 typeId; - UInt32 varId; - std::optional valueId; - }; - struct SpirvWriter::State { State() : @@ -387,7 +380,8 @@ namespace Nz state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId); } - Visit(functionStatements[funcIndex]); + SpirvStatementVisitor visitor(*this); + visitor.Visit(functionStatements[funcIndex]); if (func.returnType == ShaderNodes::BasicType::Void) state.instructions.Append(SpirvOp::OpReturn); @@ -480,12 +474,6 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); } - UInt32 SpirvWriter::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr) - { - Visit(expr); - return PopResultId(); - } - UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const { return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); @@ -507,6 +495,43 @@ namespace Nz }); } + auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar& + { + auto it = m_currentState->builtinIds.find(builtin); + assert(it != m_currentState->builtinIds.end()); + + return it->second; + } + + auto SpirvWriter::GetInputVariable(const std::string& name) const -> const ExtVar& + { + auto it = m_currentState->inputIds.find(name); + assert(it != m_currentState->inputIds.end()); + + return it->second; + } + + auto SpirvWriter::GetOutputVariable(const std::string& name) const -> const ExtVar& + { + auto it = m_currentState->outputIds.find(name); + assert(it != m_currentState->outputIds.end()); + + return it->second; + } + + auto SpirvWriter::GetUniformVariable(const std::string& name) const -> const ExtVar& + { + auto it = m_currentState->uniformIds.find(name); + assert(it != m_currentState->uniformIds.end()); + + return it.value(); + } + + SpirvSection& SpirvWriter::GetInstructions() + { + return m_currentState->instructions; + } + UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const { return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass)); @@ -517,20 +542,53 @@ namespace Nz return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type)); } - void SpirvWriter::PushResultId(UInt32 value) + UInt32 SpirvWriter::ReadInputVariable(const std::string& name) { - m_currentState->resultIds.push_back(value); + auto it = m_currentState->inputIds.find(name); + assert(it != m_currentState->inputIds.end()); + + return ReadVariable(it.value()); } - UInt32 SpirvWriter::PopResultId() + std::optional SpirvWriter::ReadInputVariable(const std::string& name, OnlyCache) { - if (m_currentState->resultIds.empty()) - throw std::runtime_error("invalid operation"); + auto it = m_currentState->inputIds.find(name); + assert(it != m_currentState->inputIds.end()); - UInt32 resultId = m_currentState->resultIds.back(); - m_currentState->resultIds.pop_back(); + return ReadVariable(it.value(), OnlyCache{}); + } - return resultId; + UInt32 SpirvWriter::ReadLocalVariable(const std::string& name) + { + auto it = m_currentState->varToResult.find(name); + assert(it != m_currentState->varToResult.end()); + + return it->second; + } + + std::optional SpirvWriter::ReadLocalVariable(const std::string& name, OnlyCache) + { + auto it = m_currentState->varToResult.find(name); + if (it == m_currentState->varToResult.end()) + return {}; + + return it->second; + } + + UInt32 SpirvWriter::ReadUniformVariable(const std::string& name) + { + auto it = m_currentState->uniformIds.find(name); + assert(it != m_currentState->uniformIds.end()); + + return ReadVariable(it.value()); + } + + std::optional SpirvWriter::ReadUniformVariable(const std::string& name, OnlyCache) + { + auto it = m_currentState->uniformIds.find(name); + assert(it != m_currentState->uniformIds.end()); + + return ReadVariable(it.value(), OnlyCache{}); } UInt32 SpirvWriter::ReadVariable(ExtVar& var) @@ -546,6 +604,14 @@ namespace Nz return var.valueId.value(); } + std::optional SpirvWriter::ReadVariable(const ExtVar& var, OnlyCache) + { + if (!var.valueId.has_value()) + return {}; + + return var.valueId.value(); + } + UInt32 SpirvWriter::RegisterConstant(const ShaderConstantValue& value) { return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value)); @@ -578,572 +644,10 @@ namespace Nz return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type)); } - void SpirvWriter::Visit(ShaderNodes::AccessMember& node) + void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId) { - UInt32 pointerId; - SpirvStorageClass storage; - - switch (node.structExpr->GetType()) - { - case ShaderNodes::NodeType::Identifier: - { - auto& identifier = static_cast(*node.structExpr); - switch (identifier.var->GetType()) - { - case ShaderNodes::VariableType::BuiltinVariable: - { - auto& builtinvar = static_cast(*identifier.var); - auto it = m_currentState->builtinIds.find(builtinvar.entry); - assert(it != m_currentState->builtinIds.end()); - - pointerId = it->second.varId; - break; - } - - case ShaderNodes::VariableType::InputVariable: - { - auto& inputVar = static_cast(*identifier.var); - auto it = m_currentState->inputIds.find(inputVar.name); - assert(it != m_currentState->inputIds.end()); - - storage = SpirvStorageClass::Input; - - pointerId = it->second.varId; - break; - } - - case ShaderNodes::VariableType::OutputVariable: - { - auto& outputVar = static_cast(*identifier.var); - auto it = m_currentState->outputIds.find(outputVar.name); - assert(it != m_currentState->outputIds.end()); - - storage = SpirvStorageClass::Output; - - pointerId = it->second.varId; - break; - } - - case ShaderNodes::VariableType::UniformVariable: - { - auto& uniformVar = static_cast(*identifier.var); - auto it = m_currentState->uniformIds.find(uniformVar.name); - assert(it != m_currentState->uniformIds.end()); - - storage = SpirvStorageClass::Uniform; - - pointerId = it->second.varId; - break; - } - - case ShaderNodes::VariableType::LocalVariable: - case ShaderNodes::VariableType::ParameterVariable: - default: - throw std::runtime_error("not yet implemented"); - } - break; - } - - case ShaderNodes::NodeType::SwizzleOp: //< TODO - default: - throw std::runtime_error("not yet implemented"); - } - - UInt32 memberPointerId = AllocateResultId(); - UInt32 pointerType = RegisterPointerType(node.exprType, storage); //< FIXME - UInt32 typeId = GetTypeId(node.exprType); - - m_currentState->instructions.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) - { - appender(pointerType); - appender(memberPointerId); - appender(pointerId); - - for (std::size_t index : node.memberIndices) - appender(GetConstantId(Int32(index))); - }); - - UInt32 resultId = AllocateResultId(); - - m_currentState->instructions.Append(SpirvOp::OpLoad, typeId, resultId, memberPointerId); - - PushResultId(resultId); - } - - void SpirvWriter::Visit(ShaderNodes::AssignOp& node) - { - UInt32 result = EvaluateExpression(node.right); - - switch (node.left->GetType()) - { - case ShaderNodes::NodeType::Identifier: - { - auto& identifier = static_cast(*node.left); - switch (identifier.var->GetType()) - { - case ShaderNodes::VariableType::BuiltinVariable: - { - auto& builtinvar = static_cast(*identifier.var); - auto it = m_currentState->builtinIds.find(builtinvar.entry); - assert(it != m_currentState->builtinIds.end()); - - m_currentState->instructions.Append(SpirvOp::OpStore, it->second.varId, result); - PushResultId(result); - break; - } - - case ShaderNodes::VariableType::OutputVariable: - { - auto& outputVar = static_cast(*identifier.var); - auto it = m_currentState->outputIds.find(outputVar.name); - assert(it != m_currentState->outputIds.end()); - - m_currentState->instructions.Append(SpirvOp::OpStore, it->second.varId, result); - PushResultId(result); - break; - } - - case ShaderNodes::VariableType::InputVariable: - case ShaderNodes::VariableType::LocalVariable: - case ShaderNodes::VariableType::ParameterVariable: - case ShaderNodes::VariableType::UniformVariable: - default: - throw std::runtime_error("not yet implemented"); - } - break; - } - - case ShaderNodes::NodeType::SwizzleOp: //< TODO - default: - throw std::runtime_error("not yet implemented"); - } - } - - void SpirvWriter::Visit(ShaderNodes::Branch& node) - { - throw std::runtime_error("not yet implemented"); - } - - void SpirvWriter::Visit(ShaderNodes::BinaryOp& node) - { - ShaderExpressionType resultExprType = node.GetExpressionType(); - assert(std::holds_alternative(resultExprType)); - - const ShaderExpressionType& leftExprType = node.left->GetExpressionType(); - assert(std::holds_alternative(leftExprType)); - - const ShaderExpressionType& rightExprType = node.right->GetExpressionType(); - assert(std::holds_alternative(rightExprType)); - - ShaderNodes::BasicType resultType = std::get(resultExprType); - ShaderNodes::BasicType leftType = std::get(leftExprType); - ShaderNodes::BasicType rightType = std::get(rightExprType); - - - UInt32 leftOperand = EvaluateExpression(node.left); - UInt32 rightOperand = EvaluateExpression(node.right); - UInt32 resultId = AllocateResultId(); - - bool swapOperands = false; - - SpirvOp op = [&] - { - switch (node.op) - { - case ShaderNodes::BinaryType::Add: - { - switch (leftType) - { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: - return SpirvOp::OpFAdd; - - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: - return SpirvOp::OpIAdd; - - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: - break; - } - } - - case ShaderNodes::BinaryType::Substract: - { - switch (leftType) - { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: - return SpirvOp::OpFSub; - - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: - return SpirvOp::OpISub; - - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: - break; - } - } - - case ShaderNodes::BinaryType::Divide: - { - switch (leftType) - { - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: - return SpirvOp::OpFDiv; - - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - return SpirvOp::OpSDiv; - - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: - return SpirvOp::OpUDiv; - - case ShaderNodes::BasicType::Boolean: - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: - break; - } - } - - case ShaderNodes::BinaryType::Equality: - { - switch (leftType) - { - case ShaderNodes::BasicType::Boolean: - return SpirvOp::OpLogicalEqual; - - case ShaderNodes::BasicType::Float1: - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - case ShaderNodes::BasicType::Mat4x4: - return SpirvOp::OpFOrdEqual; - - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: - return SpirvOp::OpIEqual; - - case ShaderNodes::BasicType::Sampler2D: - case ShaderNodes::BasicType::Void: - break; - } - } - - case ShaderNodes::BinaryType::Multiply: - { - switch (leftType) - { - case ShaderNodes::BasicType::Float1: - { - switch (rightType) - { - case ShaderNodes::BasicType::Float1: - return SpirvOp::OpFMul; - - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - swapOperands = true; - return SpirvOp::OpVectorTimesScalar; - - case ShaderNodes::BasicType::Mat4x4: - swapOperands = true; - return SpirvOp::OpMatrixTimesScalar; - - default: - break; - } - - break; - } - - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - { - switch (rightType) - { - case ShaderNodes::BasicType::Float1: - return SpirvOp::OpVectorTimesScalar; - - case ShaderNodes::BasicType::Float2: - case ShaderNodes::BasicType::Float3: - case ShaderNodes::BasicType::Float4: - return SpirvOp::OpFMul; - - case ShaderNodes::BasicType::Mat4x4: - return SpirvOp::OpVectorTimesMatrix; - - default: - break; - } - - break; - } - - case ShaderNodes::BasicType::Int1: - case ShaderNodes::BasicType::Int2: - case ShaderNodes::BasicType::Int3: - case ShaderNodes::BasicType::Int4: - case ShaderNodes::BasicType::UInt1: - case ShaderNodes::BasicType::UInt2: - case ShaderNodes::BasicType::UInt3: - case ShaderNodes::BasicType::UInt4: - return SpirvOp::OpIMul; - - case ShaderNodes::BasicType::Mat4x4: - { - switch (rightType) - { - case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar; - case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector; - case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix; - - default: - break; - } - - break; - } - - default: - break; - } - break; - } - } - - assert(false); - throw std::runtime_error("unexpected binary operation"); - }(); - - if (swapOperands) - std::swap(leftOperand, rightOperand); - - m_currentState->instructions.Append(op, GetTypeId(resultType), resultId, leftOperand, rightOperand); - PushResultId(resultId); - } - - void SpirvWriter::Visit(ShaderNodes::Cast& node) - { - const ShaderExpressionType& targetExprType = node.exprType; - assert(std::holds_alternative(targetExprType)); - - ShaderNodes::BasicType targetType = std::get(targetExprType); - - StackVector exprResults = NazaraStackVector(UInt32, node.expressions.size()); - - for (const auto& exprPtr : node.expressions) - { - if (!exprPtr) - break; - - exprResults.push_back(EvaluateExpression(exprPtr)); - } - - UInt32 resultId = AllocateResultId(); - - m_currentState->instructions.AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) - { - appender(GetTypeId(targetType)); - appender(resultId); - - for (UInt32 exprResultId : exprResults) - appender(exprResultId); - }); - - PushResultId(resultId); - } - - void SpirvWriter::Visit(ShaderNodes::Constant& node) - { - std::visit([&] (const auto& value) - { - PushResultId(GetConstantId(value)); - }, node.value); - } - - void SpirvWriter::Visit(ShaderNodes::DeclareVariable& node) - { - if (node.expression) - { - assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable); - - const auto& localVar = static_cast(*node.variable); - m_currentState->varToResult[localVar.name] = EvaluateExpression(node.expression); - } - } - - void SpirvWriter::Visit(ShaderNodes::ExpressionStatement& node) - { - Visit(node.expression); - PopResultId(); - } - - void SpirvWriter::Visit(ShaderNodes::Identifier& node) - { - Visit(node.var); - } - - void SpirvWriter::Visit(ShaderNodes::IntrinsicCall& node) - { - switch (node.intrinsic) - { - case ShaderNodes::IntrinsicType::DotProduct: - { - const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType(); - assert(std::holds_alternative(vecExprType)); - - ShaderNodes::BasicType vecType = std::get(vecExprType); - - UInt32 typeId = GetTypeId(node.GetComponentType(vecType)); - - UInt32 vec1 = EvaluateExpression(node.parameters[0]); - UInt32 vec2 = EvaluateExpression(node.parameters[1]); - - UInt32 resultId = AllocateResultId(); - - m_currentState->instructions.Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2); - PushResultId(resultId); - break; - } - - case ShaderNodes::IntrinsicType::CrossProduct: - default: - throw std::runtime_error("not yet implemented"); - } - } - - void SpirvWriter::Visit(ShaderNodes::Sample2D& node) - { - UInt32 typeId = GetTypeId(ShaderNodes::BasicType::Float4); - - UInt32 samplerId = EvaluateExpression(node.sampler); - UInt32 coordinatesId = EvaluateExpression(node.coordinates); - UInt32 resultId = AllocateResultId(); - - m_currentState->instructions.Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId); - PushResultId(resultId); - } - - void SpirvWriter::Visit(ShaderNodes::StatementBlock& node) - { - for (auto& statement : node.statements) - Visit(statement); - } - - void SpirvWriter::Visit(ShaderNodes::SwizzleOp& node) - { - const ShaderExpressionType& targetExprType = node.GetExpressionType(); - assert(std::holds_alternative(targetExprType)); - - ShaderNodes::BasicType targetType = std::get(targetExprType); - - UInt32 exprResultId = EvaluateExpression(node.expression); - UInt32 resultId = AllocateResultId(); - - if (node.componentCount > 1) - { - // Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands - m_currentState->instructions.AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) - { - appender(GetTypeId(targetType)); - appender(resultId); - appender(exprResultId); - appender(exprResultId); - - for (std::size_t i = 0; i < node.componentCount; ++i) - appender(UInt32(node.components[0]) - UInt32(node.components[i])); - }); - } - else - { - // Extract a single component from the vector - assert(node.componentCount == 1); - - m_currentState->instructions.Append(SpirvOp::OpCompositeExtract, GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) ); - } - - PushResultId(resultId); - } - - void SpirvWriter::Visit(ShaderNodes::BuiltinVariable& var) - { - throw std::runtime_error("not implemented yet"); - } - - void SpirvWriter::Visit(ShaderNodes::InputVariable& var) - { - auto it = m_currentState->inputIds.find(var.name); - assert(it != m_currentState->inputIds.end()); - - PushResultId(ReadVariable(it.value())); - } - - void SpirvWriter::Visit(ShaderNodes::LocalVariable& var) - { - auto it = m_currentState->varToResult.find(var.name); - assert(it != m_currentState->varToResult.end()); - - PushResultId(it->second); - } - - void SpirvWriter::Visit(ShaderNodes::OutputVariable& var) - { - auto it = m_currentState->outputIds.find(var.name); - assert(it != m_currentState->outputIds.end()); - - PushResultId(ReadVariable(it.value())); - } - - void SpirvWriter::Visit(ShaderNodes::ParameterVariable& var) - { - throw std::runtime_error("not implemented yet"); - } - - void SpirvWriter::Visit(ShaderNodes::UniformVariable& var) - { - auto it = m_currentState->uniformIds.find(var.name); - assert(it != m_currentState->uniformIds.end()); - - PushResultId(ReadVariable(it.value())); + assert(m_currentState); + m_currentState->varToResult.insert_or_assign(std::move(name), resultId); } void SpirvWriter::MergeBlocks(std::vector& output, const SpirvSection& from)