ShaderLang: Proof of concept (add support for a lot of things)

This commit is contained in:
Jérôme Leclercq 2021-03-31 10:21:35 +02:00
parent 2a73005295
commit c1d1838336
37 changed files with 2259 additions and 908 deletions

View File

@ -37,12 +37,12 @@
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstNodes.hpp>
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>

View File

@ -0,0 +1,24 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADERAST_ATTRIBUTES_HPP
#define NAZARA_SHADERAST_ATTRIBUTES_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
namespace Nz::ShaderAst
{
struct Attribute
{
using Param = std::variant<std::monostate, long long, std::string>;
AttributeType type;
Param args;
};
}
#endif

View File

@ -0,0 +1,96 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_AST_EXPRESSIONTYPE_HPP
#define NAZARA_SHADER_AST_EXPRESSIONTYPE_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Utility/Enums.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <string>
#include <variant>
#include <vector>
namespace Nz::ShaderAst
{
struct IdentifierType //< Alias or struct
{
std::string name;
inline bool operator==(const IdentifierType& rhs) const;
inline bool operator!=(const IdentifierType& rhs) const;
};
struct MatrixType
{
std::size_t columnCount;
std::size_t rowCount;
PrimitiveType type;
inline bool operator==(const MatrixType& rhs) const;
inline bool operator!=(const MatrixType& rhs) const;
};
struct NoType
{
inline bool operator==(const NoType& rhs) const;
inline bool operator!=(const NoType& rhs) const;
};
struct SamplerType
{
ImageType dim;
PrimitiveType sampledType;
inline bool operator==(const SamplerType& rhs) const;
inline bool operator!=(const SamplerType& rhs) const;
};
struct UniformType
{
IdentifierType containedType;
inline bool operator==(const UniformType& rhs) const;
inline bool operator!=(const UniformType& rhs) const;
};
struct VectorType
{
std::size_t componentCount;
PrimitiveType type;
inline bool operator==(const VectorType& rhs) const;
inline bool operator!=(const VectorType& rhs) const;
};
using ExpressionType = std::variant<NoType, IdentifierType, PrimitiveType, MatrixType, SamplerType, UniformType, VectorType>;
struct StructDescription
{
struct StructMember
{
std::string name;
std::vector<Attribute> attributes;
ExpressionType type;
};
std::string name;
std::vector<StructMember> members;
};
inline bool IsIdentifierType(const ExpressionType& type);
inline bool IsMatrixType(const ExpressionType& type);
inline bool IsNoType(const ExpressionType& type);
inline bool IsPrimitiveType(const ExpressionType& type);
inline bool IsSamplerType(const ExpressionType& type);
inline bool IsUniformType(const ExpressionType& type);
inline bool IsVectorType(const ExpressionType& type);
}
#include <Nazara/Shader/Ast/ExpressionType.inl>
#endif

View File

@ -0,0 +1,111 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline bool IdentifierType::operator==(const IdentifierType& rhs) const
{
return name == rhs.name;
}
inline bool IdentifierType::operator!=(const IdentifierType& rhs) const
{
return !operator==(rhs);
}
inline bool MatrixType::operator==(const MatrixType& rhs) const
{
return columnCount == rhs.columnCount && rowCount == rhs.rowCount && type == rhs.type;
}
inline bool MatrixType::operator!=(const MatrixType& rhs) const
{
return !operator==(rhs);
}
inline bool NoType::operator==(const NoType& /*rhs*/) const
{
return true;
}
inline bool NoType::operator!=(const NoType& /*rhs*/) const
{
return false;
}
inline bool SamplerType::operator==(const SamplerType& rhs) const
{
return dim == rhs.dim && sampledType == rhs.sampledType;
}
inline bool SamplerType::operator!=(const SamplerType& rhs) const
{
return !operator==(rhs);
}
inline bool UniformType::operator==(const UniformType& rhs) const
{
return containedType == rhs.containedType;
}
inline bool UniformType::operator!=(const UniformType& rhs) const
{
return !operator==(rhs);
}
inline bool VectorType::operator==(const VectorType& rhs) const
{
return componentCount == rhs.componentCount && type == rhs.type;
}
inline bool VectorType::operator!=(const VectorType& rhs) const
{
return !operator==(rhs);
}
inline bool IsIdentifierType(const ExpressionType& type)
{
return std::holds_alternative<IdentifierType>(type);
}
inline bool IsMatrixType(const ExpressionType& type)
{
return std::holds_alternative<MatrixType>(type);
}
inline bool IsNoType(const ExpressionType& type)
{
return std::holds_alternative<NoType>(type);
}
inline bool IsPrimitiveType(const ExpressionType& type)
{
return std::holds_alternative<PrimitiveType>(type);
}
inline bool IsSamplerType(const ExpressionType& type)
{
return std::holds_alternative<SamplerType>(type);
}
bool IsUniformType(const ExpressionType& type)
{
return std::holds_alternative<UniformType>(type);
}
bool IsVectorType(const ExpressionType& type)
{
return std::holds_alternative<VectorType>(type);
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -28,7 +28,7 @@ namespace Nz
GlslWriter(GlslWriter&&) = delete;
~GlslWriter() = default;
std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {});
std::string Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {});
void SetEnv(Environment environment);
@ -44,17 +44,26 @@ namespace Nz
static const char* GetFlipYUniformName();
private:
void Append(ShaderAst::ShaderExpressionType type);
void Append(const ShaderAst::ExpressionType& type);
void Append(ShaderAst::BuiltinEntry builtin);
void Append(ShaderAst::BasicType type);
void Append(const ShaderAst::IdentifierType& identifierType);
void Append(const ShaderAst::MatrixType& matrixType);
void Append(ShaderAst::MemoryLayout layout);
void Append(ShaderAst::NoType);
void Append(ShaderAst::PrimitiveType type);
void Append(const ShaderAst::SamplerType& samplerType);
void Append(const ShaderAst::UniformType& uniformType);
void Append(const ShaderAst::VectorType& vecType);
template<typename T> void Append(const T& param);
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
void AppendCommentSection(const std::string& section);
void AppendEntryPoint(ShaderStageType shaderStage);
void AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);
void EnterScope();
void LeaveScope();
void LeaveScope(bool skipLine = true);
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);
@ -70,7 +79,9 @@ namespace Nz
void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareStructStatement& node) override;
void Visit(ShaderAst::DeclareVariableStatement& node) override;
void Visit(ShaderAst::DiscardStatement& node) override;
void Visit(ShaderAst::ExpressionStatement& node) override;

View File

@ -16,15 +16,22 @@ namespace Nz::ShaderAst
{
struct AstCache
{
struct Identifier;
struct Alias
{
std::variant<ExpressionType> value;
};
struct Variable
{
ShaderExpressionType type;
ExpressionType type;
};
struct Identifier
{
std::string name;
std::variant<Variable, StructDescription> value;
std::variant<Alias, Variable, StructDescription> value;
};
struct Scope
@ -33,12 +40,12 @@ namespace Nz::ShaderAst
std::vector<Identifier> identifiers;
};
inline void Clear();
inline const Identifier* FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const;
inline std::size_t GetScopeId(const Node* node) const;
ShaderStageType stageType = ShaderStageType::Undefined;
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_map<const Expression*, ShaderExpressionType> nodeExpressionType;
std::unordered_map<const Expression*, ExpressionType> nodeExpressionType;
std::unordered_map<const Node*, std::size_t> scopeIdByNode;
std::vector<Scope> scopes;
};

View File

@ -7,6 +7,14 @@
namespace Nz::ShaderAst
{
inline void AstCache::Clear()
{
entryFunctions.fill(nullptr);
nodeExpressionType.clear();
scopeIdByNode.clear();
scopes.clear();
}
inline auto AstCache::FindIdentifier(std::size_t startingScopeId, const std::string& identifierName) const -> const Identifier*
{
assert(startingScopeId < scopes.size());
@ -28,7 +36,7 @@ namespace Nz::ShaderAst
inline std::size_t AstCache::GetScopeId(const Node* node) const
{
auto it = scopeIdByNode.find(node);
assert(it == scopeIdByNode.end());
assert(it != scopeIdByNode.end());
return it->second;
}

View File

@ -33,6 +33,8 @@ namespace Nz::ShaderAst
ExpressionPtr CloneExpression(ExpressionPtr& expr);
StatementPtr CloneStatement(StatementPtr& statement);
virtual std::unique_ptr<DeclareFunctionStatement> Clone(DeclareFunctionStatement& node);
using AstExpressionVisitor::Visit;
using AstStatementVisitor::Visit;
@ -45,8 +47,10 @@ namespace Nz::ShaderAst
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;

View File

@ -10,7 +10,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <vector>
namespace Nz::ShaderAst
@ -25,13 +25,14 @@ namespace Nz::ShaderAst
ExpressionTypeVisitor(ExpressionTypeVisitor&&) = delete;
~ExpressionTypeVisitor() = default;
ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache);
ExpressionType GetExpressionType(Expression& expression, AstCache* cache);
ExpressionTypeVisitor& operator=(const ExpressionTypeVisitor&) = delete;
ExpressionTypeVisitor& operator=(ExpressionTypeVisitor&&) = delete;
private:
ShaderExpressionType GetExpressionTypeInternal(Expression& expression);
ExpressionType GetExpressionTypeInternal(Expression& expression);
ExpressionType ResolveAlias(Expression& expression, ExpressionType expressionType);
void Visit(Expression& expression);
@ -46,10 +47,10 @@ namespace Nz::ShaderAst
void Visit(SwizzleExpression& node) override;
AstCache* m_cache;
std::optional<ShaderExpressionType> m_lastExpressionType;
std::optional<ExpressionType> m_lastExpressionType;
};
inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr);
inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache = nullptr);
}
#include <Nazara/Shader/ShaderAstExpressionType.inl>

View File

@ -7,7 +7,7 @@
namespace Nz::ShaderAst
{
inline ShaderExpressionType GetExpressionType(Expression& expression, AstCache* cache)
inline ExpressionType GetExpressionType(Expression& expression, AstCache* cache)
{
ExpressionTypeVisitor visitor;
return visitor.GetExpressionType(expression, cache);

View File

@ -37,6 +37,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement)
NAZARA_SHADERAST_STATEMENT(ConditionalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement)
NAZARA_SHADERAST_STATEMENT(DeclareStructStatement)
NAZARA_SHADERAST_STATEMENT(DeclareVariableStatement)

View File

@ -32,6 +32,7 @@ namespace Nz::ShaderAst
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;

View File

@ -35,6 +35,7 @@ namespace Nz::ShaderAst
void Serialize(BranchStatement& node);
void Serialize(ConditionalStatement& node);
void Serialize(DeclareExternalStatement& node);
void Serialize(DeclareFunctionStatement& node);
void Serialize(DeclareStructStatement& node);
void Serialize(DeclareVariableStatement& node);
@ -45,6 +46,7 @@ namespace Nz::ShaderAst
void Serialize(ReturnStatement& node);
protected:
void Attributes(std::vector<Attribute>& attributes);
template<typename T> void Container(T& container);
template<typename T> void Enum(T& enumVal);
template<typename T> void OptEnum(std::optional<T>& optVal);
@ -55,7 +57,7 @@ namespace Nz::ShaderAst
virtual void Node(ExpressionPtr& node) = 0;
virtual void Node(StatementPtr& node) = 0;
virtual void Type(ShaderExpressionType& type) = 0;
virtual void Type(ExpressionType& type) = 0;
virtual void Value(bool& val) = 0;
virtual void Value(float& val) = 0;
@ -86,7 +88,7 @@ namespace Nz::ShaderAst
bool IsWriting() const override;
void Node(ExpressionPtr& node) override;
void Node(StatementPtr& node) override;
void Type(ShaderExpressionType& type) override;
void Type(ExpressionType& type) override;
void Value(bool& val) override;
void Value(float& val) override;
void Value(std::string& val) override;
@ -117,7 +119,7 @@ namespace Nz::ShaderAst
bool IsWriting() const override;
void Node(ExpressionPtr& node) override;
void Node(StatementPtr& node) override;
void Type(ShaderExpressionType& type) override;
void Type(ExpressionType& type) override;
void Value(bool& val) override;
void Value(float& val) override;
void Value(std::string& val) override;

View File

@ -1,40 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#pragma once
#ifndef NAZARA_SHADER_ASTTYPES_HPP
#define NAZARA_SHADER_ASTTYPES_HPP
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <string>
#include <variant>
#include <vector>
namespace Nz::ShaderAst
{
using ShaderExpressionType = std::variant<BasicType, std::string>;
struct StructDescription
{
struct StructMember
{
std::string name;
ShaderExpressionType type;
};
std::string name;
std::vector<StructMember> members;
};
inline bool IsBasicType(const ShaderExpressionType& type);
inline bool IsMatrixType(const ShaderExpressionType& type);
inline bool IsSamplerType(const ShaderExpressionType& type);
inline bool IsStructType(const ShaderExpressionType& type);
}
#include <Nazara/Shader/ShaderAstTypes.inl>
#endif // NAZARA_SHADER_ASTTYPES_HPP

View File

@ -1,104 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
inline bool IsBasicType(const ShaderExpressionType& type)
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, BasicType>)
return true;
else if constexpr (std::is_same_v<T, std::string>)
return false;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
}
inline bool IsMatrixType(const ShaderExpressionType& type)
{
if (!IsBasicType(type))
return false;
switch (std::get<BasicType>(type))
{
case BasicType::Mat4x4:
return true;
case BasicType::Boolean:
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::Sampler2D:
case BasicType::Void:
case BasicType::UInt1:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
return false;
}
return false;
}
inline bool IsSamplerType(const ShaderExpressionType& type)
{
if (!IsBasicType(type))
return false;
switch (std::get<BasicType>(type))
{
case BasicType::Sampler2D:
return true;
case BasicType::Boolean:
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::Mat4x4:
case BasicType::Void:
case BasicType::UInt1:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
return false;
}
return false;
}
inline bool IsStructType(const ShaderExpressionType& type)
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, BasicType>)
return false;
else if constexpr (std::is_same_v<T, std::string>)
return true;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -29,14 +29,14 @@ namespace Nz::ShaderAst
Expression& MandatoryExpr(ExpressionPtr& node);
Statement& MandatoryStatement(StatementPtr& node);
void TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right);
void TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right);
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right);
ShaderExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
ExpressionType CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
AstCache::Scope& EnterScope();
void ExitScope();
void RegisterExpressionType(Expression& node, ShaderExpressionType expressionType);
void RegisterExpressionType(Expression& node, ExpressionType expressionType);
void RegisterScope(Node& node);
void Visit(AccessMemberExpression& node) override;
@ -51,6 +51,7 @@ namespace Nz::ShaderAst
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;

View File

@ -15,6 +15,11 @@ namespace Nz::ShaderBuilder
{
namespace Impl
{
struct Assign
{
inline std::unique_ptr<ShaderAst::AssignExpression> operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const;
};
struct Binary
{
inline std::unique_ptr<ShaderAst::BinaryExpression> operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const;
@ -26,6 +31,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::BranchStatement> operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement = nullptr) const;
};
struct Cast
{
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const;
};
struct Constant
{
inline std::unique_ptr<ShaderAst::ConstantExpression> operator()(ShaderConstantValue value) const;
@ -33,13 +43,24 @@ namespace Nz::ShaderBuilder
struct DeclareFunction
{
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const;
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const;
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
};
struct DeclareStruct
{
inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(ShaderAst::StructDescription description) const;
inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(std::vector<ShaderAst::Attribute> attributes, ShaderAst::StructDescription description) const;
};
struct DeclareVariable
{
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const;
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const;
};
struct ExpressionStatement
{
inline std::unique_ptr<ShaderAst::ExpressionStatement> operator()(ShaderAst::ExpressionPtr expression) const;
};
struct Identifier
@ -47,9 +68,9 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::IdentifierExpression> operator()(std::string name) const;
};
struct Return
struct Intrinsic
{
inline std::unique_ptr<ShaderAst::ReturnStatement> operator()(ShaderAst::ExpressionPtr expr = nullptr) const;
inline std::unique_ptr<ShaderAst::IntrinsicExpression> operator()(ShaderAst::IntrinsicType intrinsicType, std::vector<ShaderAst::ExpressionPtr> parameters) const;
};
template<typename T>
@ -57,15 +78,25 @@ namespace Nz::ShaderBuilder
{
std::unique_ptr<T> operator()() const;
};
struct Return
{
inline std::unique_ptr<ShaderAst::ReturnStatement> operator()(ShaderAst::ExpressionPtr expr = nullptr) const;
};
}
constexpr Impl::Assign Assign;
constexpr Impl::Binary Binary;
constexpr Impl::Branch Branch;
constexpr Impl::Cast Cast;
constexpr Impl::Constant Constant;
constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareStruct DeclareStruct;
constexpr Impl::DeclareVariable DeclareVariable;
constexpr Impl::ExpressionStatement ExpressionStatement;
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::Identifier Identifier;
constexpr Impl::Intrinsic Intrinsic;
constexpr Impl::NoParam<ShaderAst::NoOpStatement> NoOp;
constexpr Impl::Return Return;
}

View File

@ -7,14 +7,24 @@
namespace Nz::ShaderBuilder
{
inline std::unique_ptr<ShaderAst::AssignExpression> Impl::Assign::operator()(ShaderAst::AssignType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const
{
auto assignNode = std::make_unique<ShaderAst::AssignExpression>();
assignNode->op = op;
assignNode->left = std::move(left);
assignNode->right = std::move(right);
return assignNode;
}
inline std::unique_ptr<ShaderAst::BinaryExpression> Impl::Binary::operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const
{
auto constantNode = std::make_unique<ShaderAst::BinaryExpression>();
constantNode->op = op;
constantNode->left = std::move(left);
constantNode->right = std::move(right);
auto binaryNode = std::make_unique<ShaderAst::BinaryExpression>();
binaryNode->op = op;
binaryNode->left = std::move(left);
binaryNode->right = std::move(right);
return constantNode;
return binaryNode;
}
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const
@ -39,6 +49,18 @@ namespace Nz::ShaderBuilder
return branchNode;
}
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const
{
auto castNode = std::make_unique<ShaderAst::CastExpression>();
castNode->targetType = std::move(targetType);
assert(expressions.size() <= castNode->expressions.size());
for (std::size_t i = 0; i < expressions.size(); ++i)
castNode->expressions[i] = std::move(expressions[i]);
return castNode;
}
inline std::unique_ptr<ShaderAst::ConstantExpression> Impl::Constant::operator()(ShaderConstantValue value) const
{
auto constantNode = std::make_unique<ShaderAst::ConstantExpression>();
@ -47,7 +69,7 @@ namespace Nz::ShaderBuilder
return constantNode;
}
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ShaderExpressionType returnType) const
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType) const
{
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();
declareFunctionNode->name = std::move(name);
@ -58,7 +80,7 @@ namespace Nz::ShaderBuilder
return declareFunctionNode;
}
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ShaderExpressionType returnType) const
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType) const
{
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();
declareFunctionNode->attributes = std::move(attributes);
@ -70,7 +92,24 @@ namespace Nz::ShaderBuilder
return declareFunctionNode;
}
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue) const
inline std::unique_ptr<ShaderAst::DeclareStructStatement> Impl::DeclareStruct::operator()(ShaderAst::StructDescription description) const
{
auto declareStructNode = std::make_unique<ShaderAst::DeclareStructStatement>();
declareStructNode->description = std::move(description);
return declareStructNode;
}
inline std::unique_ptr<ShaderAst::DeclareStructStatement> Impl::DeclareStruct::operator()(std::vector<ShaderAst::Attribute> attributes, ShaderAst::StructDescription description) const
{
auto declareStructNode = std::make_unique<ShaderAst::DeclareStructStatement>();
declareStructNode->attributes = std::move(attributes);
declareStructNode->description = std::move(description);
return declareStructNode;
}
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
{
auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>();
declareVariableNode->varName = std::move(name);
@ -80,6 +119,14 @@ namespace Nz::ShaderBuilder
return declareVariableNode;
}
inline std::unique_ptr<ShaderAst::ExpressionStatement> Impl::ExpressionStatement::operator()(ShaderAst::ExpressionPtr expression) const
{
auto expressionStatementNode = std::make_unique<ShaderAst::ExpressionStatement>();
expressionStatementNode->expression = std::move(expression);
return expressionStatementNode;
}
inline std::unique_ptr<ShaderAst::IdentifierExpression> Impl::Identifier::operator()(std::string name) const
{
auto identifierNode = std::make_unique<ShaderAst::IdentifierExpression>();
@ -88,6 +135,15 @@ namespace Nz::ShaderBuilder
return identifierNode;
}
inline std::unique_ptr<ShaderAst::IntrinsicExpression> Impl::Intrinsic::operator()(ShaderAst::IntrinsicType intrinsicType, std::vector<ShaderAst::ExpressionPtr> parameters) const
{
auto intrinsicExpression = std::make_unique<ShaderAst::IntrinsicExpression>();
intrinsicExpression->intrinsic = intrinsicType;
intrinsicExpression->parameters = std::move(parameters);
return intrinsicExpression;
}
inline std::unique_ptr<ShaderAst::ReturnStatement> Impl::Return::operator()(ShaderAst::ExpressionPtr expr) const
{
auto returnNode = std::make_unique<ShaderAst::ReturnStatement>();

View File

@ -18,28 +18,19 @@ namespace Nz::ShaderAst
enum class AttributeType
{
Entry, //< Entry point (function only) - has argument type
Layout //< Struct layout (struct only) - has argument style
Binding, //< Binding (external var only) - has argument index
Builtin, //< Builtin (struct member only) - has argument type
Entry, //< Entry point (function only) - has argument type
Layout, //< Struct layout (struct only) - has argument style
Location //< Location (struct member only) - has argument index
};
enum class BasicType
enum class PrimitiveType
{
Boolean, //< bool
Float1, //< float
Float2, //< vec2
Float3, //< vec3
Float4, //< vec4
Int1, //< int
Int2, //< ivec2
Int3, //< ivec3
Int4, //< ivec4
Mat4x4, //< mat4
Sampler2D, //< sampler2D
Void, //< void
UInt1, //< uint
UInt2, //< uvec2
UInt3, //< uvec3
UInt4 //< uvec4
Boolean, //< bool
Float32, //< f32
Int32, //< i32
UInt32, //< ui32
};
enum class BinaryType
@ -71,7 +62,8 @@ namespace Nz::ShaderAst
enum class IntrinsicType
{
CrossProduct,
DotProduct
DotProduct,
SampleTexture
};
enum class MemoryLayout
@ -107,9 +99,6 @@ namespace Nz::ShaderAst
ParameterVariable,
UniformVariable
};
inline std::size_t GetComponentCount(BasicType type);
inline BasicType GetComponentType(BasicType type);
}
#include <Nazara/Shader/ShaderEnums.inl>

View File

@ -7,51 +7,6 @@
namespace Nz::ShaderAst
{
inline std::size_t GetComponentCount(BasicType type)
{
switch (type)
{
case BasicType::Float2:
case BasicType::Int2:
return 2;
case BasicType::Float3:
case BasicType::Int3:
return 3;
case BasicType::Float4:
case BasicType::Int4:
return 4;
case BasicType::Mat4x4:
return 4;
default:
return 1;
}
}
inline BasicType GetComponentType(BasicType type)
{
switch (type)
{
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
return BasicType::Float1;
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
return BasicType::Int1;
case BasicType::Mat4x4:
return BasicType::Float4;
default:
return type;
}
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -19,6 +19,12 @@ namespace Nz::ShaderLang
public:
using exception::exception;
};
class DuplicateIdentifier : public std::exception
{
public:
using exception::exception;
};
class ReservedKeyword : public std::exception
{
@ -56,17 +62,24 @@ namespace Nz::ShaderLang
// Flow control
const Token& Advance();
void Consume(std::size_t count = 1);
ShaderAst::ExpressionType DecodeType(const std::string& identifier);
void EnterScope();
const Token& Expect(const Token& token, TokenType type);
const Token& ExpectNot(const Token& token, TokenType type);
const Token& Expect(TokenType type);
void LeaveScope();
bool IsVariableInScope(const std::string_view& identifier) const;
void RegisterVariable(std::string identifier);
const Token& Peek(std::size_t advance = 0);
void HandleAttributes();
std::vector<ShaderAst::Attribute> ParseAttributes();
// Statements
ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {});
std::vector<ShaderAst::StatementPtr> ParseFunctionBody();
ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter();
ShaderAst::StatementPtr ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes = {});
ShaderAst::StatementPtr ParseReturnStatement();
ShaderAst::StatementPtr ParseStatement();
std::vector<ShaderAst::StatementPtr> ParseStatementList();
@ -75,22 +88,28 @@ namespace Nz::ShaderLang
// Expressions
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);
ShaderAst::ExpressionPtr ParseExpression();
ShaderAst::ExpressionPtr ParseFloatingPointExpression(bool minus = false);
ShaderAst::ExpressionPtr ParseIdentifier();
ShaderAst::ExpressionPtr ParseIntegerExpression();
ShaderAst::ExpressionPtr ParseIntegerExpression(bool minus = false);
std::vector<ShaderAst::ExpressionPtr> ParseParameters();
ShaderAst::ExpressionPtr ParseParenthesisExpression();
ShaderAst::ExpressionPtr ParsePrimaryExpression();
ShaderAst::ExpressionPtr ParseVariableAssignation();
ShaderAst::AttributeType ParseIdentifierAsAttributeType();
const std::string& ParseIdentifierAsName();
ShaderAst::ShaderExpressionType ParseIdentifierAsType();
ShaderAst::PrimitiveType ParsePrimitiveType();
ShaderAst::ExpressionType ParseType();
static int GetTokenPrecedence(TokenType token);
struct Context
{
std::unique_ptr<ShaderAst::MultiStatement> root;
std::size_t tokenCount;
std::size_t tokenIndex = 0;
std::vector<std::size_t> scopeSizes;
std::vector<std::string> identifiersInScope;
std::unique_ptr<ShaderAst::MultiStatement> root;
const Token* tokens;
};

View File

@ -21,15 +21,22 @@ NAZARA_SHADERLANG_TOKEN(Colon)
NAZARA_SHADERLANG_TOKEN(Comma)
NAZARA_SHADERLANG_TOKEN(Divide)
NAZARA_SHADERLANG_TOKEN(Dot)
NAZARA_SHADERLANG_TOKEN(Equal)
NAZARA_SHADERLANG_TOKEN(External)
NAZARA_SHADERLANG_TOKEN(FloatingPointValue)
NAZARA_SHADERLANG_TOKEN(EndOfStream)
NAZARA_SHADERLANG_TOKEN(FunctionDeclaration)
NAZARA_SHADERLANG_TOKEN(FunctionReturn)
NAZARA_SHADERLANG_TOKEN(GreatherThan)
NAZARA_SHADERLANG_TOKEN(GreatherThanEqual)
NAZARA_SHADERLANG_TOKEN(IntegerValue)
NAZARA_SHADERLANG_TOKEN(Identifier)
NAZARA_SHADERLANG_TOKEN(LessThan)
NAZARA_SHADERLANG_TOKEN(LessThanEqual)
NAZARA_SHADERLANG_TOKEN(Let)
NAZARA_SHADERLANG_TOKEN(Multiply)
NAZARA_SHADERLANG_TOKEN(Minus)
NAZARA_SHADERLANG_TOKEN(NotEqual)
NAZARA_SHADERLANG_TOKEN(Plus)
NAZARA_SHADERLANG_TOKEN(OpenAttribute)
NAZARA_SHADERLANG_TOKEN(OpenCurlyBracket)
@ -37,6 +44,7 @@ NAZARA_SHADERLANG_TOKEN(OpenSquareBracket)
NAZARA_SHADERLANG_TOKEN(OpenParenthesis)
NAZARA_SHADERLANG_TOKEN(Semicolon)
NAZARA_SHADERLANG_TOKEN(Return)
NAZARA_SHADERLANG_TOKEN(Struct)
#undef NAZARA_SHADERLANG_TOKEN
#undef NAZARA_SHADERLANG_TOKEN_LAST

View File

@ -12,9 +12,10 @@
#include <Nazara/Math/Vector3.hpp>
#include <Nazara/Math/Vector4.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <array>
#include <memory>
#include <optional>
@ -25,12 +26,6 @@ namespace Nz::ShaderAst
class AstExpressionVisitor;
class AstStatementVisitor;
struct Attribute
{
AttributeType type;
std::string args;
};
struct NAZARA_SHADER_API Node
{
Node() = default;
@ -97,7 +92,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
BasicType targetType;
ExpressionType targetType;
std::array<ExpressionPtr, 4> expressions;
};
@ -189,6 +184,22 @@ namespace Nz::ShaderAst
StatementPtr statement;
};
struct NAZARA_SHADER_API DeclareExternalStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
struct ExternalVar
{
std::vector<Attribute> attributes;
std::string name;
ExpressionType type;
};
std::vector<Attribute> attributes;
std::vector<ExternalVar> externalVars;
};
struct NAZARA_SHADER_API DeclareFunctionStatement : Statement
{
NodeType GetType() const override;
@ -197,14 +208,14 @@ namespace Nz::ShaderAst
struct Parameter
{
std::string name;
ShaderExpressionType type;
ExpressionType type;
};
std::string name;
std::vector<Attribute> attributes;
std::vector<Parameter> parameters;
std::vector<StatementPtr> statements;
ShaderExpressionType returnType = BasicType::Void;
ExpressionType returnType;
};
struct NAZARA_SHADER_API DeclareStructStatement : Statement
@ -212,6 +223,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::vector<Attribute> attributes;
StructDescription description;
};
@ -222,7 +234,7 @@ namespace Nz::ShaderAst
std::string varName;
ExpressionPtr initialExpression;
ShaderExpressionType varType;
ExpressionType varType;
};
struct NAZARA_SHADER_API DiscardStatement : Statement

View File

@ -10,7 +10,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderConstantValue.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/ShaderAstTypes.hpp>
#include <Nazara/Shader/Ast/ExpressionType.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <memory>
#include <optional>
@ -172,11 +172,16 @@ namespace Nz
SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept;
static ConstantPtr BuildConstant(const ShaderConstantValue& value);
static TypePtr BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector<ShaderAst::ShaderExpressionType>& parameters);
static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass);
static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass);
static TypePtr BuildType(const ShaderAst::BasicType& type);
static TypePtr BuildType(const ShaderAst::ShaderExpressionType& type);
static TypePtr BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters);
static TypePtr BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass);
static TypePtr BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass);
static TypePtr BuildType(const ShaderAst::ExpressionType& type);
static TypePtr BuildType(const ShaderAst::IdentifierType& type);
static TypePtr BuildType(const ShaderAst::MatrixType& type);
static TypePtr BuildType(const ShaderAst::NoType& type);
static TypePtr BuildType(const ShaderAst::PrimitiveType& type);
static TypePtr BuildType(const ShaderAst::SamplerType& type);
static TypePtr BuildType(const ShaderAst::VectorType& type);
private:
struct DepRegisterer;

View File

@ -62,8 +62,8 @@ namespace Nz
const ExtVar& GetInputVariable(const std::string& name) const;
const ExtVar& GetOutputVariable(const std::string& name) const;
const ExtVar& GetUniformVariable(const std::string& name) const;
UInt32 GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const;
UInt32 GetTypeId(const ShaderAst::ShaderExpressionType& type) const;
UInt32 GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const;
UInt32 GetTypeId(const ShaderAst::ExpressionType& type) const;
inline bool IsConditionEnabled(const std::string& condition) const;
@ -80,8 +80,8 @@ namespace Nz
UInt32 RegisterConstant(const ShaderConstantValue& value);
UInt32 RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode);
UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass);
UInt32 RegisterType(ShaderAst::ShaderExpressionType type);
UInt32 RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass);
UInt32 RegisterType(ShaderAst::ExpressionType type);
void WriteLocalVariable(std::string name, UInt32 resultId);
@ -106,7 +106,7 @@ namespace Nz
struct FunctionParameter
{
std::string name;
ShaderAst::ShaderExpressionType type;
ShaderAst::ExpressionType type;
};
struct State;

View File

@ -20,37 +20,43 @@ namespace Nz
namespace
{
static const char* flipYUniformName = "_NzFlipValue";
static const char* overridenMain = "_NzMain";
struct AstAdapter : ShaderAst::AstCloner
{
void Visit(ShaderAst::AssignExpression& node) override
using AstCloner::Clone;
std::unique_ptr<ShaderAst::DeclareFunctionStatement> Clone(ShaderAst::DeclareFunctionStatement& node) override
{
if (!flipYPosition)
return AstCloner::Visit(node);
auto clone = AstCloner::Clone(node);
if (clone->name == "main")
clone->name = "_NzMain";
if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression)
return AstCloner::Visit(node);
/*
FIXME:
const auto& identifier = static_cast<const ShaderAst::Identifier&>(*node.left);
if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable)
return ShaderAstCloner::Visit(node);
const auto& builtinVar = static_cast<const ShaderAst::BuiltinVariable&>(*identifier.var);
if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition)
return ShaderAstCloner::Visit(node);
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1);
auto oneConstant = ShaderBuilder::Constant(1.f);
auto fixYValue = ShaderBuilder::Cast<ShaderAst::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue);
PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/
return clone;
}
bool flipYPosition = false;
void Visit(ShaderAst::DeclareFunctionStatement& node)
{
if (removedEntryPoints.find(&node) != removedEntryPoints.end())
{
PushStatement(ShaderBuilder::NoOp());
return;
}
AstCloner::Visit(node);
}
std::unordered_set<ShaderAst::DeclareFunctionStatement*> removedEntryPoints;
};
struct Builtin
{
std::string identifier;
ShaderStageTypeFlags stageFlags;
};
std::unordered_map<std::string, Builtin> builtinMapping = {
{ "position", { "gl_Position", ShaderStageType::Vertex } }
};
}
@ -59,6 +65,7 @@ namespace Nz
{
const States* states = nullptr;
ShaderAst::AstCache cache;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
unsigned int indentLevel = 0;
};
@ -69,19 +76,8 @@ namespace Nz
{
}
std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
std::string GlslWriter::Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions)
{
/*const ShaderAst* selectedShader = &inputShader;
std::optional<ShaderAst> modifiedShader;
if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition)
{
modifiedShader.emplace(inputShader);
modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1);
selectedShader = &modifiedShader.value();
}*/
State state;
m_currentState = &state;
CallOnExit onExit([this]()
@ -93,6 +89,27 @@ namespace Nz
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
throw std::runtime_error("Invalid shader AST: " + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
if (!state.entryFunc)
throw std::runtime_error("missing entry point");
AstAdapter adapter;
for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions)
{
if (entryFunc != state.entryFunc)
adapter.removedEntryPoints.insert(entryFunc);
}
ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader);
state.cache.Clear();
if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache))
throw std::runtime_error("Internal error:" + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
assert(state.entryFunc);
unsigned int glslVersion;
if (m_environment.glES)
{
@ -141,14 +158,14 @@ namespace Nz
if (!m_environment.glES && m_environment.extCallback)
{
// GL_ARB_shading_language_420pack (required for layout(binding = X))
if (glslVersion < 420 && HasExplicitBinding(shader))
if (glslVersion < 420 && HasExplicitBinding(adaptedShader))
{
if (m_environment.extCallback("GL_ARB_shading_language_420pack"))
requiredExtensions.emplace_back("GL_ARB_shading_language_420pack");
}
// GL_ARB_separate_shader_objects (required for layout(location = X))
if (glslVersion < 410 && HasExplicitLocation(shader))
if (glslVersion < 410 && HasExplicitLocation(adaptedShader))
{
if (m_environment.extCallback("GL_ARB_separate_shader_objects"))
requiredExtensions.emplace_back("GL_ARB_separate_shader_objects");
@ -173,7 +190,10 @@ namespace Nz
AppendLine();
}
shader->Visit(*this);
adaptedShader->Visit(*this);
// Append true GLSL entry point
AppendEntryPoint(shaderStage);
return state.stream.str();
}
@ -188,7 +208,7 @@ namespace Nz
return flipYUniformName;
}
void GlslWriter::Append(ShaderAst::ShaderExpressionType type)
void GlslWriter::Append(const ShaderAst::ExpressionType& type)
{
std::visit([&](auto&& arg)
{
@ -206,29 +226,82 @@ namespace Nz
}
}
void GlslWriter::Append(ShaderAst::BasicType type)
void GlslWriter::Append(const ShaderAst::IdentifierType& identifierType)
{
Append(identifierType.name);
}
void GlslWriter::Append(const ShaderAst::MatrixType& matrixType)
{
if (matrixType.columnCount == matrixType.rowCount)
{
Append("mat");
Append(matrixType.columnCount);
}
else
{
Append("mat");
Append(matrixType.columnCount);
Append("x");
Append(matrixType.rowCount);
}
}
void GlslWriter::Append(ShaderAst::PrimitiveType type)
{
switch (type)
{
case ShaderAst::BasicType::Boolean: return Append("bool");
case ShaderAst::BasicType::Float1: return Append("float");
case ShaderAst::BasicType::Float2: return Append("vec2");
case ShaderAst::BasicType::Float3: return Append("vec3");
case ShaderAst::BasicType::Float4: return Append("vec4");
case ShaderAst::BasicType::Int1: return Append("int");
case ShaderAst::BasicType::Int2: return Append("ivec2");
case ShaderAst::BasicType::Int3: return Append("ivec3");
case ShaderAst::BasicType::Int4: return Append("ivec4");
case ShaderAst::BasicType::Mat4x4: return Append("mat4");
case ShaderAst::BasicType::Sampler2D: return Append("sampler2D");
case ShaderAst::BasicType::UInt1: return Append("uint");
case ShaderAst::BasicType::UInt2: return Append("uvec2");
case ShaderAst::BasicType::UInt3: return Append("uvec3");
case ShaderAst::BasicType::UInt4: return Append("uvec4");
case ShaderAst::BasicType::Void: return Append("void");
case ShaderAst::PrimitiveType::Boolean: return Append("bool");
case ShaderAst::PrimitiveType::Float32: return Append("float");
case ShaderAst::PrimitiveType::Int32: return Append("ivec2");
case ShaderAst::PrimitiveType::UInt32: return Append("uint");
}
}
void GlslWriter::Append(const ShaderAst::SamplerType& samplerType)
{
switch (samplerType.sampledType)
{
case ShaderAst::PrimitiveType::Boolean:
case ShaderAst::PrimitiveType::Float32:
break;
case ShaderAst::PrimitiveType::Int32: Append("i"); break;
case ShaderAst::PrimitiveType::UInt32: Append("u"); break;
}
Append("sampler");
switch (samplerType.dim)
{
case ImageType_1D: Append("1D"); break;
case ImageType_1D_Array: Append("1DArray"); break;
case ImageType_2D: Append("2D"); break;
case ImageType_2D_Array: Append("2DArray"); break;
case ImageType_3D: Append("3D"); break;
case ImageType_Cubemap: Append("Cube"); break;
}
}
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
{
/* TODO */
}
void GlslWriter::Append(const ShaderAst::VectorType& vecType)
{
switch (vecType.type)
{
case ShaderAst::PrimitiveType::Boolean: Append("b"); break;
case ShaderAst::PrimitiveType::Float32: break;
case ShaderAst::PrimitiveType::Int32: Append("i"); break;
case ShaderAst::PrimitiveType::UInt32: Append("u"); break;
}
Append("vec");
Append(vecType.componentCount);
}
void GlslWriter::Append(ShaderAst::MemoryLayout layout)
{
switch (layout)
@ -239,6 +312,11 @@ namespace Nz
}
}
void GlslWriter::Append(ShaderAst::NoType)
{
return Append("void");
}
template<typename T>
void GlslWriter::Append(const T& param)
{
@ -246,6 +324,12 @@ namespace Nz
m_currentState->stream << param;
}
template<typename T1, typename T2, typename... Args>
void GlslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params)
{
Append(firstParam);
Append(secondParam, std::forward<Args>(params)...);
}
void GlslWriter::AppendCommentSection(const std::string& section)
{
@ -256,6 +340,152 @@ namespace Nz
AppendLine();
}
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage)
{
AppendLine();
AppendLine("// Entry point handling");
struct InOutField
{
std::string name;
std::string targetName;
};
std::vector<InOutField> inputFields;
const ShaderAst::StructDescription* inputStruct = nullptr;
auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
{
assert(IsIdentifierType(expressionType));
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::IdentifierType>(expressionType).name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
for (const auto& member : s.members)
{
bool skip = false;
std::optional<std::string> builtinName;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = builtinMapping.find(std::get<std::string>(attributeParam));
if (it != builtinMapping.end())
{
const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(shaderStage))
{
skip = true;
break;
}
builtinName = builtin.identifier;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
if (!skip && attributeLocation)
{
Append("layout(location = ");
Append(*attributeLocation);
Append(") ");
Append(keyword);
Append(" ");
Append(member.type);
Append(" ");
Append(targetPrefix);
Append(member.name);
AppendLine(";");
fields.push_back({
fromPrefix + member.name,
targetPrefix + member.name
});
}
else if (builtinName)
{
fields.push_back({
fromPrefix + member.name,
*builtinName
});
}
}
AppendLine();
return &s;
};
if (!m_currentState->entryFunc->parameters.empty())
{
assert(m_currentState->entryFunc->parameters.size() == 1);
const auto& parameter = m_currentState->entryFunc->parameters.front();
inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_");
}
std::vector<InOutField> outputFields;
const ShaderAst::StructDescription* outputStruct = nullptr;
if (!IsNoType(m_currentState->entryFunc->returnType))
outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_");
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
AppendLine("uniform float ", flipYUniformName, ";");
AppendLine("void main()");
EnterScope();
{
if (inputStruct)
{
Append(inputStruct->name);
AppendLine(" _nzInput;");
for (const auto& [name, targetName] : inputFields)
{
AppendLine(name, " = ", targetName, ";");
}
AppendLine();
}
if (outputStruct)
Append(outputStruct->name, " _nzOutput = ");
Append(m_currentState->entryFunc->name);
Append("(");
if (m_currentState->entryFunc)
Append("_nzInput");
Append(");");
if (outputStruct)
{
AppendLine();
for (const auto& [name, targetName] : outputFields)
{
bool isOutputPosition = (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition && targetName == "gl_Position");
AppendLine();
Append(targetName, " = ", name);
if (isOutputPosition)
Append(" * vec4(1.0, ", flipYUniformName, ", 1.0, 1.0)");
Append(";");
}
}
}
LeaveScope();
}
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
Append(".");
@ -273,7 +503,7 @@ namespace Nz
const auto& member = *memberIt;
if (remainingMembers > 1)
AppendField(scopeId, std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
}
void GlslWriter::AppendLine(const std::string& txt)
@ -283,6 +513,13 @@ namespace Nz
m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t');
}
template<typename... Args>
void GlslWriter::AppendLine(Args&&... params)
{
(Append(std::forward<Args>(params)), ...);
AppendLine();
}
void GlslWriter::EnterScope()
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@ -291,13 +528,17 @@ namespace Nz
AppendLine("{");
}
void GlslWriter::LeaveScope()
void GlslWriter::LeaveScope(bool skipLine)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
m_currentState->indentLevel--;
AppendLine();
AppendLine("}");
if (skipLine)
AppendLine("}");
else
Append("}");
}
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
@ -317,12 +558,12 @@ namespace Nz
{
Visit(node.structExpr, true);
const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
assert(IsStructType(exprType));
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
assert(IsIdentifierType(exprType));
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
AppendField(scopeId, std::get<std::string>(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size());
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
@ -336,7 +577,7 @@ namespace Nz
break;
}
node.left->Visit(*this);
node.right->Visit(*this);
}
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
@ -455,6 +696,71 @@ namespace Nz
}, node.value);
}
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : externalVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
bindingIndex = std::get<long long>(attributeParam);
else if (attributeType == ShaderAst::AttributeType::Layout)
{
if (std::get<std::string>(attributeParam) == "std140")
isStd140 = true;
}
}
if (bindingIndex)
{
Append("layout(binding = ");
Append(*bindingIndex);
if (isStd140)
Append(", std140");
Append(") uniform ");
if (IsUniformType(externalVar.type))
{
Append("_NzBinding_");
AppendLine(externalVar.name);
EnterScope();
{
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
bool first = true;
for (const auto& [name, attribute, type] : s.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(" ");
Append(name);
Append(";");
}
}
LeaveScope(false);
}
else
Append(externalVar.type);
Append(" ");
Append(externalVar.name);
AppendLine(";");
}
}
}
void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@ -475,15 +781,36 @@ namespace Nz
EnterScope();
{
AstAdapter adapter;
adapter.flipYPosition = m_environment.flipYPosition;
for (auto& statement : node.statements)
adapter.Clone(statement)->Visit(*this);
statement->Visit(*this);
}
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
Append("struct ");
AppendLine(node.description.name);
EnterScope();
{
bool first = true;
for (const auto& [name, attribute, type] : node.description.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(" ");
Append(name);
Append(";");
}
}
LeaveScope(false);
AppendLine(";");
}
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
Append(node.varType);
@ -506,7 +833,7 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::ExpressionStatement& node)
{
node.expression->Visit(*this);
Append(";");
AppendLine(";");
}
void GlslWriter::Visit(ShaderAst::IdentifierExpression& node)
@ -525,6 +852,10 @@ namespace Nz
case ShaderAst::IntrinsicType::DotProduct:
Append("dot");
break;
case ShaderAst::IntrinsicType::SampleTexture:
Append("texture");
break;
}
Append("(");
@ -624,4 +955,5 @@ namespace Nz
return false;
}
}

View File

@ -42,6 +42,21 @@ namespace Nz::ShaderAst
return PopStatement();
}
std::unique_ptr<DeclareFunctionStatement> AstCloner::Clone(DeclareFunctionStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
return clone;
}
void AstCloner::Visit(AccessMemberExpression& node)
{
auto clone = std::make_unique<AccessMemberExpression>();
@ -162,21 +177,20 @@ namespace Nz::ShaderAst
PushStatement(std::move(clone));
}
void AstCloner::Visit(DeclareFunctionStatement& node)
void AstCloner::Visit(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
clone->externalVars = node.externalVars;
PushStatement(std::move(clone));
}
void AstCloner::Visit(DeclareFunctionStatement& node)
{
PushStatement(Clone(node));
}
void AstCloner::Visit(DeclareStructStatement& node)
{
auto clone = std::make_unique<DeclareStructStatement>();

View File

@ -9,16 +9,16 @@
namespace Nz::ShaderAst
{
ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr)
ExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache)
{
m_cache = cache;
ShaderExpressionType type = GetExpressionTypeInternal(expression);
ExpressionType type = GetExpressionTypeInternal(expression);
m_cache = nullptr;
return type;
}
ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
ExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
{
m_lastExpressionType.reset();
@ -28,6 +28,33 @@ namespace Nz::ShaderAst
return std::move(*m_lastExpressionType);
}
ExpressionType ExpressionTypeVisitor::ResolveAlias(Expression& expression, ExpressionType expressionType)
{
if (IsIdentifierType(expressionType))
{
auto scopeIt = m_cache->scopeIdByNode.find(&expression);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
if (identifier && std::holds_alternative<AstCache::Alias>(identifier->value))
{
const AstCache::Alias& alias = std::get<AstCache::Alias>(identifier->value);
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ExpressionType>)
return arg;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, alias.value);
}
}
return expressionType;
}
void ExpressionTypeVisitor::Visit(Expression& expression)
{
if (m_cache)
@ -51,6 +78,16 @@ namespace Nz::ShaderAst
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
{
auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
ExpressionType expressionType = ResolveAlias(node, GetExpressionTypeInternal(*node.structExpr));
if (!IsIdentifierType(expressionType))
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
throw std::runtime_error("unhandled accessmember expression");
}
@ -70,38 +107,35 @@ namespace Nz::ShaderAst
case BinaryType::Divide:
case BinaryType::Multiply:
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left);
assert(IsBasicType(leftExprType));
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.left));
ExpressionType rightExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.right));
ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right);
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
if (IsPrimitiveType(leftExprType))
{
case BasicType::Boolean:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
m_lastExpressionType = std::move(leftExprType);
break;
switch (std::get<PrimitiveType>(leftExprType))
{
case PrimitiveType::Boolean:
m_lastExpressionType = std::move(leftExprType);
break;
case BasicType::Float1:
case BasicType::Int1:
case BasicType::Mat4x4:
case BasicType::UInt1:
m_lastExpressionType = std::move(rightExprType);
break;
case BasicType::Sampler2D:
case BasicType::Void:
break;
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
m_lastExpressionType = std::move(rightExprType);
break;
}
}
else if (IsMatrixType(leftExprType))
{
if (IsVectorType(rightExprType))
m_lastExpressionType = std::move(rightExprType);
else
m_lastExpressionType = std::move(leftExprType);
}
else if (IsVectorType(leftExprType))
m_lastExpressionType = std::move(leftExprType);
else
throw std::runtime_error("validation failure");
break;
}
@ -112,7 +146,7 @@ namespace Nz::ShaderAst
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
m_lastExpressionType = BasicType::Boolean;
m_lastExpressionType = PrimitiveType::Boolean;
break;
}
}
@ -124,38 +158,38 @@ namespace Nz::ShaderAst
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath);
assert(leftExprType == GetExpressionTypeInternal(*node.falsePath));
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.truePath));
assert(leftExprType == ResolveAlias(node, GetExpressionTypeInternal(*node.falsePath)));
m_lastExpressionType = std::move(leftExprType);
}
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
{
m_lastExpressionType = std::visit([&](auto&& arg)
m_lastExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return BasicType::Boolean;
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return BasicType::Float1;
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return BasicType::Int1;
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return BasicType::Int1;
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return BasicType::Float2;
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return BasicType::Float3;
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return BasicType::Float4;
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return BasicType::Int2;
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return BasicType::Int3;
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return BasicType::Int4;
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, node.value);
@ -173,7 +207,7 @@ namespace Nz::ShaderAst
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
throw std::runtime_error("internal error");
m_lastExpressionType = std::get<AstCache::Variable>(identifier->value).type;
m_lastExpressionType = ResolveAlias(node, std::get<AstCache::Variable>(identifier->value).type);
}
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
@ -185,16 +219,40 @@ namespace Nz::ShaderAst
break;
case IntrinsicType::DotProduct:
m_lastExpressionType = BasicType::Float1;
m_lastExpressionType = PrimitiveType::Float32;
break;
case IntrinsicType::SampleTexture:
{
if (node.parameters.empty())
throw std::runtime_error("validation failure");
ExpressionType firstParamType = ResolveAlias(node, GetExpressionTypeInternal(*node.parameters.front()));
if (!IsSamplerType(firstParamType))
throw std::runtime_error("validation failure");
const auto& sampler = std::get<SamplerType>(firstParamType);
m_lastExpressionType = VectorType{
4,
sampler.sampledType
};
break;
}
}
}
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
{
ShaderExpressionType exprType = GetExpressionTypeInternal(*node.expression);
assert(IsBasicType(exprType));
ExpressionType exprType = GetExpressionTypeInternal(*node.expression);
m_lastExpressionType = static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + node.componentCount - 1);
if (IsMatrixType(exprType))
m_lastExpressionType = std::get<MatrixType>(exprType).type;
else if (IsVectorType(exprType))
m_lastExpressionType = std::get<VectorType>(exprType).type;
else
throw std::runtime_error("validation failure");
}
}

View File

@ -453,8 +453,8 @@ namespace Nz::ShaderAst
{
auto& constant = static_cast<ConstantExpression&>(*cond);
assert(IsBasicType(GetExpressionType(constant)));
assert(std::get<BasicType>(GetExpressionType(constant)) == BasicType::Boolean);
assert(IsPrimitiveType(GetExpressionType(constant)));
assert(std::get<PrimitiveType>(GetExpressionType(constant)) == PrimitiveType::Boolean);
bool cValue = std::get<bool>(constant.value);
if (!cValue)

View File

@ -79,6 +79,11 @@ namespace Nz::ShaderAst
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareExternalStatement& node)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node)
{
for (auto& statement : node.statements)

View File

@ -58,7 +58,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(CastExpression& node)
{
Enum(node.targetType);
Type(node.targetType);
for (auto& expr : node.expressions)
Node(expr);
}
@ -152,17 +152,25 @@ namespace Nz::ShaderAst
Node(node.statement);
}
void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{
Attributes(node.attributes);
Container(node.externalVars);
for (auto& extVar : node.externalVars)
{
Attributes(extVar.attributes);
Value(extVar.name);
Type(extVar.type);
}
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);
Type(node.returnType);
Container(node.attributes);
for (auto& attribute : node.attributes)
{
Enum(attribute.type);
Value(attribute.args);
}
Attributes(node.attributes);
Container(node.parameters);
for (auto& parameter : node.parameters)
@ -223,6 +231,78 @@ namespace Nz::ShaderAst
m_stream.FlushBits();
}
void AstSerializerBase::Attributes(std::vector<Attribute>& attributes)
{
Container(attributes);
for (auto& attribute : attributes)
{
Enum(attribute.type);
if (IsWriting())
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::monostate>)
{
UInt8 typeId = 0;
Value(typeId);
}
else if constexpr (std::is_same_v<T, long long>)
{
UInt8 typeId = 1;
UInt64 v = UInt64(arg);
Value(typeId);
Value(v);
}
else if constexpr (std::is_same_v<T, std::string>)
{
UInt8 typeId = 2;
Value(typeId);
Value(arg);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, attribute.args);
}
else
{
UInt8 typeId;
Value(typeId);
switch (typeId)
{
case 0:
attribute.args.emplace<std::monostate>();
break;
case 1:
{
UInt64 arg;
Value(arg);
attribute.args = static_cast<long long>(arg);
break;
}
case 2:
{
std::string arg;
Value(arg);
attribute.args = std::move(arg);
break;
}
default:
throw std::runtime_error("invalid attribute type id");
}
}
}
}
bool ShaderAstSerializer::IsWriting() const
{
@ -253,20 +333,47 @@ namespace Nz::ShaderAst
}
}
void ShaderAstSerializer::Type(ShaderExpressionType& type)
void ShaderAstSerializer::Type(ExpressionType& type)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, BasicType>)
{
if constexpr (std::is_same_v<T, NoType>)
m_stream << UInt8(0);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, std::string>)
else if constexpr (std::is_same_v<T, PrimitiveType>)
{
m_stream << UInt8(1);
m_stream << arg;
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
m_stream << UInt8(2);
m_stream << arg.name;
}
else if constexpr (std::is_same_v<T, MatrixType>)
{
m_stream << UInt8(3);
m_stream << UInt32(arg.columnCount);
m_stream << UInt32(arg.rowCount);
m_stream << UInt32(arg.type);
}
else if constexpr (std::is_same_v<T, SamplerType>)
{
m_stream << UInt8(4);
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
}
else if constexpr (std::is_same_v<T, UniformType>)
{
m_stream << UInt8(5);
m_stream << arg.containedType.name;
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(6);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
@ -421,28 +528,123 @@ namespace Nz::ShaderAst
}
}
void ShaderAstUnserializer::Type(ShaderExpressionType& type)
void ShaderAstUnserializer::Type(ExpressionType& type)
{
UInt8 typeIndex;
Value(typeIndex);
switch (typeIndex)
{
case 0: //< Primitive
/*
if constexpr (std::is_same_v<T, NoType>)
m_stream << UInt8(0);
else if constexpr (std::is_same_v<T, PrimitiveType>)
{
BasicType exprType;
Enum(exprType);
m_stream << UInt8(1);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
m_stream << UInt8(2);
m_stream << arg.name;
}
else if constexpr (std::is_same_v<T, MatrixType>)
{
m_stream << UInt8(3);
m_stream << UInt32(arg.columnCount);
m_stream << UInt32(arg.rowCount);
m_stream << UInt32(arg.type);
}
else if constexpr (std::is_same_v<T, SamplerType>)
{
m_stream << UInt8(4);
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(5);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
}
*/
type = exprType;
case 0: //< NoType
type = NoType{};
break;
case 1: //< PrimitiveType
{
PrimitiveType primitiveType;
Enum(primitiveType);
type = primitiveType;
break;
}
case 1: //< Struct (name)
case 2: //< Identifier
{
std::string structName;
Value(structName);
std::string identifier;
Value(identifier);
type = std::move(structName);
type = IdentifierType{ std::move(identifier) };
break;
}
case 3: //< MatrixType
{
UInt32 columnCount, rowCount;
PrimitiveType primitiveType;
Value(columnCount);
Value(rowCount);
Enum(primitiveType);
type = MatrixType {
columnCount,
rowCount,
primitiveType
};
break;
}
case 4: //< SamplerType
{
ImageType dim;
PrimitiveType sampledType;
Enum(dim);
Enum(sampledType);
type = SamplerType {
dim,
sampledType
};
break;
}
case 5: //< UniformType
{
std::string containedType;
Value(containedType);
type = UniformType {
IdentifierType {
containedType
}
};
break;
}
case 6: //< VectorType
{
UInt32 componentCount;
PrimitiveType componentType;
Value(componentCount);
Enum(componentType);
type = VectorType{
componentCount,
componentType
};
break;
}

View File

@ -18,7 +18,6 @@ namespace Nz::ShaderAst
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
}
struct AstError
@ -30,6 +29,8 @@ namespace Nz::ShaderAst
{
//const ShaderAst::Function* currentFunction;
std::optional<std::size_t> activeScopeId;
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<long long> usedBindingIndexes;;
AstCache* cache;
};
@ -81,31 +82,31 @@ namespace Nz::ShaderAst
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
}
void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right)
{
if (left != right)
throw AstError{ "Left expression type must match right expression type" };
}
ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
if (!identifier)
throw AstError{ "unknown identifier " + structName };
if (std::holds_alternative<StructDescription>(identifier->value))
if (!std::holds_alternative<StructDescription>(identifier->value))
throw AstError{ "identifier is not a struct" };
const StructDescription& s = std::get<StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
if (memberIt == s.members.end())
throw AstError{ "unknown field " + memberIdentifier[0]};
const auto& member = *memberIt;
if (remainingMembers > 1)
return CheckField(std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
return CheckField(std::get<IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
else
return member.type;
}
@ -130,7 +131,7 @@ namespace Nz::ShaderAst
m_context->activeScopeId = previousScope.parentScopeIndex;
}
void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType)
void AstValidator::RegisterExpressionType(Expression& node, ExpressionType expressionType)
{
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
}
@ -145,11 +146,14 @@ namespace Nz::ShaderAst
{
RegisterScope(node);
ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
if (!IsStructType(exprType))
// Register expressions types
AstRecursiveVisitor::Visit(node);
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
if (!IsIdentifierType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
const std::string& structName = std::get<IdentifierType>(exprType).name;
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
}
@ -160,12 +164,14 @@ namespace Nz::ShaderAst
MandatoryExpr(node.left);
MandatoryExpr(node.right);
// Register expressions types
AstRecursiveVisitor::Visit(node);
TypeMustMatch(node.left, node.right);
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError { "Assignation is only possible with a l-value" };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(BinaryExpression& node)
@ -175,80 +181,121 @@ namespace Nz::ShaderAst
// Register expression type
AstRecursiveVisitor::Visit(node);
ShaderExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
if (!IsBasicType(leftExprType))
ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
ShaderExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
if (!IsBasicType(rightExprType))
ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
BasicType leftType = std::get<BasicType>(leftExprType);
BasicType rightType = std::get<BasicType>(rightExprType);
switch (node.op)
if (IsPrimitiveType(leftExprType))
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == BasicType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
switch (node.op)
{
switch (leftType)
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == PrimitiveType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
case BasicType::Float1:
case BasicType::Int1:
switch (leftType)
{
if (GetComponentType(rightType) != leftType)
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
{
if (leftType != rightType && rightType != GetComponentType(leftType))
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case BasicType::Mat4x4:
{
switch (rightType)
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
{
case BasicType::Float1:
case BasicType::Float4:
case BasicType::Mat4x4:
break;
if (IsMatrixType(rightExprType))
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
else if (IsVectorType(rightExprType))
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
else
throw AstError{ "incompatible types" };
default:
TypeMustMatch(node.left, node.right);
break;
}
break;
}
case PrimitiveType::Boolean:
throw AstError{ "this operation is not supported for booleans" };
default:
TypeMustMatch(node.left, node.right);
break;
default:
throw AstError{ "incompatible types" };
}
}
}
}
else if (IsMatrixType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
switch (node.op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsMatrixType(rightExprType))
TypeMustMatch(leftExprType, rightExprType);
else if (IsPrimitiveType(rightExprType))
TypeMustMatch(leftType.type, rightExprType);
else if (IsVectorType(rightExprType))
{
const VectorType& rightType = std::get<VectorType>(rightExprType);
TypeMustMatch(leftType.type, rightType.type);
if (leftType.columnCount != rightType.componentCount)
throw AstError{ "incompatible types" };
}
else
throw AstError{ "incompatible types" };
}
}
}
else if (IsVectorType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
switch (node.op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsPrimitiveType(rightExprType))
TypeMustMatch(leftType.type, rightExprType);
else
throw AstError{ "incompatible types" };
}
}
}
@ -258,24 +305,35 @@ namespace Nz::ShaderAst
{
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
auto GetComponentCount = [](const ExpressionType& exprType) -> unsigned int
{
if (IsPrimitiveType(exprType))
return 1;
else if (IsVectorType(exprType))
return std::get<VectorType>(exprType).componentCount;
else
throw AstError{ "wut" };
};
unsigned int componentCount = 0;
unsigned int requiredComponents = GetComponentCount(node.targetType);
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
if (!IsBasicType(exprType))
ExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "incompatible type" };
componentCount += GetComponentCount(std::get<BasicType>(exprType));
componentCount += GetComponentCount(exprType);
}
if (componentCount != requiredComponents)
throw AstError{ "component count doesn't match required component count" };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(ConditionalExpression& node)
@ -313,34 +371,51 @@ namespace Nz::ShaderAst
{
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
case IntrinsicType::DotProduct:
{
if (node.parameters.size() != 2)
throw AstError { "Expected 2 parameters" };
throw AstError { "Expected two parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
ExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
for (std::size_t i = 1; i < node.parameters.size(); ++i)
{
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache)
if (type != GetExpressionType(MandatoryExpr(node.parameters[i]), m_context->cache))
throw AstError{ "All type must match" };
}
break;
}
case IntrinsicType::SampleTexture:
{
if (node.parameters.size() != 2)
throw AstError{ "Expected two parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
if (!IsSamplerType(GetExpressionType(*node.parameters[0], m_context->cache)))
throw AstError{ "First parameter must be a sampler" };
if (!IsVectorType(GetExpressionType(*node.parameters[1], m_context->cache)))
throw AstError{ "First parameter must be a vector" };
}
}
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
{
if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache)
throw AstError{ "CrossProduct only works with Float3 expressions" };
if (GetExpressionType(*node.parameters[0]) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
break;
}
@ -348,8 +423,6 @@ namespace Nz::ShaderAst
case IntrinsicType::DotProduct:
break;
}
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(SwizzleExpression& node)
@ -359,26 +432,10 @@ namespace Nz::ShaderAst
if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" };
ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsBasicType(exprType))
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<BasicType>(exprType))
{
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
break;
default:
throw AstError{ "Cannot swizzle this type" };
}
AstRecursiveVisitor::Visit(node);
}
@ -388,8 +445,8 @@ namespace Nz::ShaderAst
for (auto& condStatement : node.condStatements)
{
ShaderExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
if (!IsBasicType(condType) || std::get<BasicType>(condType) != BasicType::Boolean)
ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
MandatoryStatement(condStatement.statement);
@ -409,6 +466,78 @@ namespace Nz::ShaderAst
// throw AstError{ "condition not found" };
}
void AstValidator::Visit(DeclareExternalStatement& node)
{
RegisterScope(node);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
{
default:
throw AstError{ "unhandled attribute for external block" };
}
}
for (const auto& extVar : node.externalVars)
{
bool hasBinding = false;
bool hasLayout = false;
for (const auto& [attributeType, arg] : extVar.attributes)
{
switch (attributeType)
{
case AttributeType::Binding:
{
if (hasBinding)
throw AstError{ "attribute binding must be present once" };
if (!std::holds_alternative<long long>(arg))
throw AstError{ "attribute binding requires a string parameter" };
long long bindingIndex = std::get<long long>(arg);
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
m_context->usedBindingIndexes.insert(bindingIndex);
break;
}
case AttributeType::Layout:
{
if (hasLayout)
throw AstError{ "attribute layout must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute layout requires a string parameter" };
if (std::get<std::string>(arg) != "std140")
throw AstError{ "unknow layout type" };
hasLayout = true;
break;
}
default:
throw AstError{ "unhandled attribute for external variable" };
}
}
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
throw AstError{ "External variable " + extVar.name + " is already declared" };
m_context->declaredExternalVar.insert(extVar.name);
ExpressionType subType = extVar.type;
if (IsUniformType(subType))
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ extVar.name, AstCache::Variable { std::move(subType) } };
}
}
void AstValidator::Visit(DeclareFunctionStatement& node)
{
bool hasEntry = false;
@ -421,12 +550,14 @@ namespace Nz::ShaderAst
if (hasEntry)
throw AstError{ "attribute entry must be present once" };
if (arg.empty())
throw AstError{ "attribute entry requires a parameter" };
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute entry requires a string parameter" };
auto it = entryPoints.find(arg);
const std::string& argStr = std::get<std::string>(arg);
auto it = entryPoints.find(argStr);
if (it == entryPoints.end())
throw AstError{ "invalid parameter " + arg + " for entry attribute" };
throw AstError{ "invalid parameter " + argStr + " for entry attribute" };
ShaderStageType stageType = it->second;
@ -435,6 +566,9 @@ namespace Nz::ShaderAst
m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node;
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
hasEntry = true;
break;
}
@ -468,6 +602,8 @@ namespace Nz::ShaderAst
RegisterScope(node);
//TODO: check members attributes
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();

View File

@ -36,22 +36,24 @@ namespace Nz::ShaderLang
std::vector<Token> Tokenize(const std::string_view& str)
{
// Can't use std::from_chars for double thanks to libc++ and libstdc++ developers for being lazy
// Can't use std::from_chars for double, thanks to libc++ and libstdc++ developers for being lazy
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "true", TokenType::BoolTrue }
{ "external", TokenType::External },
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue }
};
std::size_t currentPos = 0;
auto Peek = [&](std::size_t advance = 1) -> char
{
if (currentPos + advance < str.size())
if (currentPos + advance < str.size() && str[currentPos + advance] != '\0')
return str[currentPos + advance];
else
return char(-1);
@ -134,7 +136,10 @@ namespace Nz::ShaderLang
{
currentPos++;
if (Peek() == '/')
{
currentPos++;
break;
}
}
else if (next == '\n')
{
@ -250,7 +255,48 @@ namespace Nz::ShaderLang
break;
}
case '=': tokenType = TokenType::Assign; break;
case '=':
{
char next = Peek();
if (next == '=')
{
currentPos++;
tokenType = TokenType::Equal;
}
else
tokenType = TokenType::Assign;
break;
}
case '<':
{
char next = Peek();
if (next == '=')
{
currentPos++;
tokenType = TokenType::LessThanEqual;
}
else
tokenType = TokenType::LessThan;
break;
}
case '>':
{
char next = Peek();
if (next == '=')
{
currentPos++;
tokenType = TokenType::GreatherThanEqual;
}
else
tokenType = TokenType::GreatherThan;
break;
}
case '+': tokenType = TokenType::Plus; break;
case '*': tokenType = TokenType::Multiply; break;
case ':': tokenType = TokenType::Colon; break;

View File

@ -11,32 +11,24 @@ namespace Nz::ShaderLang
{
namespace
{
std::unordered_map<std::string, ShaderAst::BasicType> identifierToBasicType = {
{ "bool", ShaderAst::BasicType::Boolean },
std::unordered_map<std::string, ShaderAst::PrimitiveType> identifierToBasicType = {
{ "bool", ShaderAst::PrimitiveType::Boolean },
{ "i32", ShaderAst::PrimitiveType::Int32 },
{ "f32", ShaderAst::PrimitiveType::Float32 },
{ "u32", ShaderAst::PrimitiveType::UInt32 }
};
{ "i32", ShaderAst::BasicType::Int1 },
{ "vec2i32", ShaderAst::BasicType::Int2 },
{ "vec3i32", ShaderAst::BasicType::Int3 },
{ "vec4i32", ShaderAst::BasicType::Int4 },
{ "f32", ShaderAst::BasicType::Float1 },
{ "vec2f32", ShaderAst::BasicType::Float2 },
{ "vec3f32", ShaderAst::BasicType::Float3 },
{ "vec4f32", ShaderAst::BasicType::Float4 },
{ "mat4x4f32", ShaderAst::BasicType::Mat4x4 },
{ "sampler2D", ShaderAst::BasicType::Sampler2D },
{ "void", ShaderAst::BasicType::Void },
{ "u32", ShaderAst::BasicType::UInt1 },
{ "vec2u32", ShaderAst::BasicType::UInt3 },
{ "vec3u32", ShaderAst::BasicType::UInt3 },
{ "vec4u32", ShaderAst::BasicType::UInt4 },
std::unordered_map<std::string, ShaderAst::IntrinsicType> identifierToIntrinsic = {
{ "cross", ShaderAst::IntrinsicType::CrossProduct },
{ "dot", ShaderAst::IntrinsicType::DotProduct },
};
std::unordered_map<std::string, ShaderAst::AttributeType> identifierToAttributeType = {
{ "entry", ShaderAst::AttributeType::Entry },
{ "layout", ShaderAst::AttributeType::Layout },
{ "binding", ShaderAst::AttributeType::Binding },
{ "builtin", ShaderAst::AttributeType::Builtin },
{ "entry", ShaderAst::AttributeType::Entry },
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
};
}
@ -50,22 +42,41 @@ namespace Nz::ShaderLang
m_context = &context;
std::vector<ShaderAst::Attribute> attributes;
EnterScope();
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
{
const Token& nextToken = Peek();
switch (nextToken.type)
{
case TokenType::EndOfStream:
if (!attributes.empty())
throw UnexpectedToken{};
reachedEndOfStream = true;
break;
case TokenType::External:
context.root->statements.push_back(ParseExternalBlock(std::move(attributes)));
attributes.clear();
break;
case TokenType::OpenAttribute:
HandleAttributes();
assert(attributes.empty());
attributes = ParseAttributes();
break;
case TokenType::FunctionDeclaration:
context.root->statements.push_back(ParseFunctionDeclaration());
context.root->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
attributes.clear();
break;
case TokenType::EndOfStream:
reachedEndOfStream = true;
case TokenType::Struct:
context.root->statements.push_back(ParseStructDeclaration(std::move(attributes)));
attributes.clear();
break;
default:
@ -73,6 +84,8 @@ namespace Nz::ShaderLang
}
}
LeaveScope();
return std::move(context.root);
}
@ -90,6 +103,92 @@ namespace Nz::ShaderLang
m_context->tokenIndex += count;
}
ShaderAst::ExpressionType Parser::DecodeType(const std::string& identifier)
{
if (auto it = identifierToBasicType.find(identifier); it != identifierToBasicType.end())
return it->second;
//FIXME: Handle this better
if (identifier == "mat4")
{
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 4;
matrixType.rowCount = 4;
Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return matrixType;
}
else if (identifier == "sampler2D")
{
ShaderAst::SamplerType samplerType;
samplerType.dim = ImageType_2D;
Expect(Advance(), TokenType::LessThan); //< '<'
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return samplerType;
}
else if (identifier == "uniform")
{
ShaderAst::UniformType uniformType;
Expect(Advance(), TokenType::LessThan); //< '<'
uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() };
Expect(Advance(), TokenType::GreatherThan); //< '>'
return uniformType;
}
else if (identifier == "vec2")
{
ShaderAst::VectorType vectorType;
vectorType.componentCount = 2;
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return vectorType;
}
else if (identifier == "vec3")
{
ShaderAst::VectorType vectorType;
vectorType.componentCount = 3;
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return vectorType;
}
else if (identifier == "vec4")
{
ShaderAst::VectorType vectorType;
vectorType.componentCount = 4;
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return vectorType;
}
else
{
ShaderAst::IdentifierType identifierType;
identifierType.name = identifier;
return identifierType;
}
}
void Parser::EnterScope()
{
m_context->scopeSizes.push_back(m_context->identifiersInScope.size());
}
const Token& Parser::Expect(const Token& token, TokenType type)
{
if (token.type != type)
@ -114,13 +213,34 @@ namespace Nz::ShaderLang
return token;
}
void Parser::LeaveScope()
{
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.resize(m_context->scopeSizes.back());
m_context->scopeSizes.pop_back();
}
bool Parser::IsVariableInScope(const std::string_view& identifier) const
{
return std::find(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), identifier) != m_context->identifiersInScope.rend();
}
void Parser::RegisterVariable(std::string identifier)
{
if (IsVariableInScope(identifier))
throw DuplicateIdentifier{ ("identifier name " + identifier + " is already taken").c_str() };
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.push_back(std::move(identifier));
}
const Token& Parser::Peek(std::size_t advance)
{
assert(m_context->tokenIndex + advance < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + advance];
}
void Parser::HandleAttributes()
std::vector<ShaderAst::Attribute> Parser::ParseAttributes()
{
std::vector<ShaderAst::Attribute> attributes;
@ -150,13 +270,22 @@ namespace Nz::ShaderLang
ShaderAst::AttributeType attributeType = ParseIdentifierAsAttributeType();
std::string arg;
ShaderAst::Attribute::Param arg;
if (Peek().type == TokenType::OpenParenthesis)
{
Consume();
if (Peek().type == TokenType::Identifier)
arg = std::get<std::string>(Advance().data);
const Token& n = Peek();
if (n.type == TokenType::Identifier)
{
arg = std::get<std::string>(n.data);
Consume();
}
else if (n.type == TokenType::IntegerValue)
{
arg = std::get<long long>(n.data);
Consume();
}
Expect(Advance(), TokenType::ClosingParenthesis);
}
@ -171,16 +300,54 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingAttribute);
const Token& nextToken = Peek();
switch (nextToken.type)
return attributes;
}
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::External);
Expect(Advance(), TokenType::OpenCurlyBracket);
std::unique_ptr<ShaderAst::DeclareExternalStatement> externalStatement = std::make_unique<ShaderAst::DeclareExternalStatement>();
externalStatement->attributes = std::move(attributes);
bool first = true;
for (;;)
{
case TokenType::FunctionDeclaration:
m_context->root->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
if (!first)
{
const Token& nextToken = Peek();
if (nextToken.type == TokenType::Comma)
Consume();
else
{
Expect(nextToken, TokenType::ClosingCurlyBracket);
break;
}
}
first = false;
const Token& token = Peek();
if (token.type == TokenType::ClosingCurlyBracket)
break;
default:
throw UnexpectedToken{};
auto& extVar = externalStatement->externalVars.emplace_back();
if (token.type == TokenType::OpenAttribute)
extVar.attributes = ParseAttributes();
extVar.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
extVar.type = ParseType();
RegisterVariable(extVar.name);
}
Expect(Advance(), TokenType::ClosingCurlyBracket);
return externalStatement;
}
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
@ -216,17 +383,23 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void;
ShaderAst::ExpressionType returnType;
if (Peek().type == TokenType::FunctionReturn)
{
Consume();
returnType = ParseIdentifierAsType();
returnType = ParseType();
}
Expect(Advance(), TokenType::OpenCurlyBracket);
EnterScope();
for (const auto& parameter : parameters)
RegisterVariable(parameter.name);
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
LeaveScope();
Expect(Advance(), TokenType::ClosingCurlyBracket);
return ShaderBuilder::DeclareFunction(std::move(attributes), std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
@ -238,11 +411,59 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType();
ShaderAst::ExpressionType parameterType = ParseType();
return { parameterName, parameterType };
}
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::Struct);
ShaderAst::StructDescription description;
description.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::OpenCurlyBracket);
bool first = true;
for (;;)
{
if (!first)
{
const Token& nextToken = Peek();
if (nextToken.type == TokenType::Comma)
Consume();
else
{
Expect(nextToken, TokenType::ClosingCurlyBracket);
break;
}
}
first = false;
const Token& token = Peek();
if (token.type == TokenType::ClosingCurlyBracket)
break;
auto& structField = description.members.emplace_back();
if (token.type == TokenType::OpenAttribute)
structField.attributes = ParseAttributes();
structField.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
structField.type = ParseType();
}
Expect(Advance(), TokenType::ClosingCurlyBracket);
return ShaderBuilder::DeclareStruct(std::move(attributes), std::move(description));
}
ShaderAst::StatementPtr Parser::ParseReturnStatement()
{
Expect(Advance(), TokenType::Return);
@ -265,6 +486,10 @@ namespace Nz::ShaderLang
statement = ParseVariableDeclaration();
break;
case TokenType::Identifier:
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
break;
case TokenType::Return:
statement = ParseReturnStatement();
break;
@ -290,15 +515,26 @@ namespace Nz::ShaderLang
return statements;
}
ShaderAst::ExpressionPtr Parser::ParseVariableAssignation()
{
ShaderAst::ExpressionPtr left = ParseIdentifier();
Expect(Advance(), TokenType::Assign);
ShaderAst::ExpressionPtr right = ParseExpression();
return ShaderBuilder::Assign(ShaderAst::AssignType::Simple, std::move(left), std::move(right));
}
ShaderAst::StatementPtr Parser::ParseVariableDeclaration()
{
Expect(Advance(), TokenType::Let);
std::string variableName = ParseIdentifierAsName();
RegisterVariable(variableName);
Expect(Advance(), TokenType::Colon);
ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType();
ShaderAst::ExpressionType variableType = ParseType();
ShaderAst::ExpressionPtr expression;
if (Peek().type == TokenType::Assign)
@ -351,18 +587,61 @@ namespace Nz::ShaderLang
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderAst::ExpressionPtr Parser::ParseFloatingPointExpression(bool minus)
{
const Token& floatingPointToken = Expect(Advance(), TokenType::FloatingPointValue);
return ShaderBuilder::Constant(((minus) ? -1.f : 1.f) * float(std::get<double>(floatingPointToken.data))); //< FIXME
}
ShaderAst::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
return ShaderBuilder::Identifier(identifier);
ShaderAst::ExpressionPtr identifierExpr = ShaderBuilder::Identifier(identifier);
if (Peek().type == TokenType::Dot)
{
std::unique_ptr<ShaderAst::AccessMemberExpression> accessMemberNode = std::make_unique<ShaderAst::AccessMemberExpression>();
accessMemberNode->structExpr = std::move(identifierExpr);
do
{
Consume();
accessMemberNode->memberIdentifiers.push_back(ParseIdentifierAsName());
} while (Peek().type == TokenType::Dot);
identifierExpr = std::move(accessMemberNode);
}
return identifierExpr;
}
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression()
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression(bool minus)
{
const Token& integerToken = Expect(Advance(), TokenType::Identifier);
return ShaderBuilder::Constant(static_cast<Nz::Int32>(std::get<long long>(integerToken.data)));
const Token& integerToken = Expect(Advance(), TokenType::IntegerValue);
return ShaderBuilder::Constant(((minus) ? -1 : 1) * static_cast<Nz::Int32>(std::get<long long>(integerToken.data)));
}
std::vector<ShaderAst::ExpressionPtr> Parser::ParseParameters()
{
Expect(Advance(), TokenType::OpenParenthesis);
std::vector<ShaderAst::ExpressionPtr> parameters;
bool first = true;
while (Peek().type != TokenType::ClosingParenthesis)
{
if (!first)
Expect(Advance(), TokenType::Comma);
first = false;
parameters.push_back(ParseExpression());
}
Expect(Advance(), TokenType::ClosingParenthesis);
return parameters;
}
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
@ -388,15 +667,69 @@ namespace Nz::ShaderLang
return ShaderBuilder::Constant(true);
case TokenType::FloatingPointValue:
Consume();
return ShaderBuilder::Constant(float(std::get<double>(token.data))); //< FIXME
return ParseFloatingPointExpression();
case TokenType::Identifier:
return ParseIdentifier();
{
const std::string& identifier = std::get<std::string>(token.data);
if (auto it = identifierToIntrinsic.find(identifier); it != identifierToIntrinsic.end())
{
if (Peek(1).type == TokenType::OpenParenthesis)
{
Consume();
return ShaderBuilder::Intrinsic(it->second, ParseParameters());
}
}
if (IsVariableInScope(identifier))
{
auto node = ParseIdentifier();
if (node->GetType() == ShaderAst::NodeType::AccessMemberExpression)
{
ShaderAst::AccessMemberExpression* memberExpr = static_cast<ShaderAst::AccessMemberExpression*>(node.get());
if (!memberExpr->memberIdentifiers.empty() && memberExpr->memberIdentifiers.front() == "Sample")
{
if (Peek().type == TokenType::OpenParenthesis)
{
auto parameters = ParseParameters();
parameters.insert(parameters.begin(), std::move(memberExpr->structExpr));
return ShaderBuilder::Intrinsic(ShaderAst::IntrinsicType::SampleTexture, std::move(parameters));
}
}
}
return node;
}
Consume();
ShaderAst::ExpressionType exprType = DecodeType(identifier);
return ShaderBuilder::Cast(std::move(exprType), ParseParameters());
}
case TokenType::IntegerValue:
return ParseIntegerExpression();
case TokenType::Minus:
//< FIXME: Handle this with an unary node
if (Peek(1).type == TokenType::FloatingPointValue)
{
Consume();
return ParseFloatingPointExpression(true);
}
else if (Peek(1).type == TokenType::IntegerValue)
{
Consume();
return ParseIntegerExpression(true);
}
else
throw UnexpectedToken{};
break;
case TokenType::OpenParenthesis:
return ParseParenthesisExpression();
@ -429,7 +762,7 @@ namespace Nz::ShaderLang
return identifier;
}
ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType()
ShaderAst::PrimitiveType Parser::ParsePrimitiveType()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
@ -441,6 +774,23 @@ namespace Nz::ShaderLang
return it->second;
}
ShaderAst::ExpressionType Parser::ParseType()
{
// Handle () as no type
if (Peek().type == TokenType::OpenParenthesis)
{
Consume();
Expect(Advance(), TokenType::ClosingParenthesis);
return ShaderAst::NoType{};
}
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
return DecodeType(identifier);
}
int Parser::GetTokenPrecedence(TokenType token)
{
switch (token)
@ -452,4 +802,5 @@ namespace Nz::ShaderLang
default: return -1;
}
}
}

View File

@ -39,18 +39,18 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{
ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsBasicType(resultExprType));
ShaderAst::ExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsPrimitiveType(resultExprType));
ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
assert(IsBasicType(leftExprType));
ShaderAst::ExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
assert(IsPrimitiveType(leftExprType));
ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
assert(IsBasicType(rightExprType));
ShaderAst::ExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
assert(IsPrimitiveType(rightExprType));
ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType);
ShaderAst::BasicType leftType = std::get<ShaderAst::BasicType>(leftExprType);
ShaderAst::BasicType rightType = std::get<ShaderAst::BasicType>(rightExprType);
ShaderAst::PrimitiveType resultType = std::get<ShaderAst::PrimitiveType>(resultExprType);
ShaderAst::PrimitiveType leftType = std::get<ShaderAst::PrimitiveType>(leftExprType);
ShaderAst::PrimitiveType rightType = std::get<ShaderAst::PrimitiveType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
@ -67,26 +67,26 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIAdd;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -97,26 +97,26 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpISub;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -127,28 +127,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSDiv;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUDiv;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -159,29 +159,29 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Boolean:
case ShaderAst::PrimitiveType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIEqual;
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
// break;
}
break;
@ -191,28 +191,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdGreaterThan;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSGreaterThan;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUGreaterThan;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -223,28 +223,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdGreaterThanEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSGreaterThanEqual;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUGreaterThanEqual;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -255,28 +255,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdLessThanEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSLessThanEqual;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpULessThanEqual;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -287,28 +287,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdLessThan;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSLessThan;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpULessThan;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@ -319,29 +319,29 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Boolean:
case ShaderAst::PrimitiveType::Boolean:
return SpirvOp::OpLogicalNotEqual;
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdNotEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpINotEqual;
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
// break;
}
break;
@ -351,22 +351,22 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::PrimitiveType::Float32:
{
switch (rightType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::PrimitiveType::Float32:
return SpirvOp::OpFMul;
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderAst::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// swapOperands = true;
// return SpirvOp::OpVectorTimesScalar;
//
// case ShaderAst::PrimitiveType::Mat4x4:
// swapOperands = true;
// return SpirvOp::OpMatrixTimesScalar;
default:
break;
@ -375,54 +375,54 @@ namespace Nz
break;
}
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
{
switch (rightType)
{
case ShaderAst::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// {
// switch (rightType)
// {
// case ShaderAst::PrimitiveType::Float32:
// return SpirvOp::OpVectorTimesScalar;
//
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// return SpirvOp::OpFMul;
//
// case ShaderAst::PrimitiveType::Mat4x4:
// return SpirvOp::OpVectorTimesMatrix;
//
// default:
// break;
// }
//
// break;
// }
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
break;
}
break;
}
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIMul;
case ShaderAst::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
}
break;
}
// case ShaderAst::PrimitiveType::Mat4x4:
// {
// switch (rightType)
// {
// case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpMatrixTimesScalar;
// case ShaderAst::PrimitiveType::Float4: return SpirvOp::OpMatrixTimesVector;
// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
//
// default:
// break;
// }
//
// break;
// }
default:
break;
@ -501,10 +501,10 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
{
const ShaderAst::ShaderExpressionType& targetExprType = node.targetType;
assert(IsBasicType(targetExprType));
const ShaderAst::ExpressionType& targetExprType = node.targetType;
assert(IsPrimitiveType(targetExprType));
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
@ -582,12 +582,12 @@ namespace Nz
{
case ShaderAst::IntrinsicType::DotProduct:
{
ShaderAst::ShaderExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
assert(IsBasicType(vecExprType));
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
assert(IsVectorType(vecExprType));
ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType);
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType));
UInt32 typeId = m_writer.GetTypeId(vecType.type);
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
@ -626,10 +626,10 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
ShaderAst::ShaderExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsBasicType(targetExprType));
ShaderAst::ExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsPrimitiveType(targetExprType));
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();

View File

@ -535,7 +535,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector2f>) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2),
BuildType(ShaderAst::VectorType{ 2, (std::is_same_v<T, Vector2f>) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
{
BuildConstant(arg.x),
BuildConstant(arg.y)
@ -545,7 +545,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector3f>) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3),
BuildType(ShaderAst::VectorType{ 3, (std::is_same_v<T, Vector3f>) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@ -556,7 +556,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector4f>) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4),
BuildType(ShaderAst::VectorType{ 4, (std::is_same_v<T, Vector4f>) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@ -570,7 +570,7 @@ namespace Nz
}, value));
}
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector<ShaderAst::ShaderExpressionType>& parameters) -> TypePtr
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) -> TypePtr
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
@ -584,7 +584,7 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(Pointer{
BuildType(type),
@ -592,85 +592,22 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
switch (type)
{
case ShaderAst::BasicType::Boolean:
return Bool{};
case ShaderAst::BasicType::Float1:
return Float{ 32 };
case ShaderAst::BasicType::Int1:
return Integer{ 32, true };
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
{
auto vecType = BuildType(ShaderAst::GetComponentType(type));
UInt32 componentCount = ShaderAst::GetComponentCount(type);
return Vector{ vecType, componentCount };
}
case ShaderAst::BasicType::Mat4x4:
return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u };
case ShaderAst::BasicType::UInt1:
return Integer{ 32, false };
case ShaderAst::BasicType::Void:
return Void{};
case ShaderAst::BasicType::Sampler2D:
{
auto imageType = Image{
{}, //< qualifier
{}, //< depth
{}, //< sampled
SpirvDim::Dim2D, //< dim
SpirvImageFormat::Unknown, //< format
BuildType(ShaderAst::BasicType::Float1), //< sampledType
false, //< arrayed,
false //< multisampled
};
return SampledImage{ std::make_shared<Type>(imageType) };
}
}
throw std::runtime_error("unexpected type");
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr
{
return std::visit([&](auto&& arg) -> TypePtr
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderAst::BasicType>)
return BuildType(arg);
else if constexpr (std::is_same_v<T, std::string>)
return BuildType(arg);
/*else if constexpr (std::is_same_v<T, std::string>)
{
/*// Register struct members type
// Register struct members type
const auto& structs = shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
@ -688,14 +625,77 @@ namespace Nz
sMembers.type = BuildType(shader, member.type);
}
return std::make_shared<Type>(std::move(sType));*/
return std::make_shared<Type>(std::move(sType));
return nullptr;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");*/
}, type);
}
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr
{
throw std::runtime_error("unexpected type");
}
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
switch (type)
{
case ShaderAst::PrimitiveType::Boolean:
return Bool{};
case ShaderAst::PrimitiveType::Float32:
return Float{ 32 };
case ShaderAst::PrimitiveType::Int32:
return Integer{ 32, true };
}
throw std::runtime_error("unexpected type");
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) -> TypePtr
{
return std::make_shared<Type>(
Matrix{
BuildType(ShaderAst::VectorType {
UInt32(type.rowCount), type.type
}),
UInt32(type.columnCount)
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) -> TypePtr
{
return std::make_shared<Type>(Void{});
}
auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) -> TypePtr
{
//TODO
auto imageType = Image{
{}, //< qualifier
{}, //< depth
{}, //< sampled
SpirvDim::Dim2D, //< dim
SpirvImageFormat::Unknown, //< format
BuildType(ShaderAst::PrimitiveType::Float32), //< sampledType
false, //< arrayed,
false //< multisampled
};
return std::make_shared<Type>(SampledImage{ std::make_shared<Type>(imageType) });
}
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr
{
return std::make_shared<Type>(Vector{ BuildType(type.type), UInt32(type.componentCount) });
}
void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants)
{
std::visit([&](auto&& arg)

View File

@ -29,7 +29,7 @@ namespace Nz
{
public:
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
@ -81,7 +81,7 @@ namespace Nz
{
funcs.emplace_back(node);
std::vector<ShaderAst::ShaderExpressionType> parameterTypes;
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
@ -92,8 +92,17 @@ namespace Nz
void Visit(ShaderAst::DeclareStructStatement& node) override
{
for (auto& field : node.description.members)
m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
SpirvConstantCache::Structure sType;
sType.name = node.description.name;
for (const auto& [name, attribute, type] : node.description.members)
{
auto& sMembers = sType.members.emplace_back();
sMembers.name = name;
sMembers.type = SpirvConstantCache::BuildType(type);
}
m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) });
}
void Visit(ShaderAst::DeclareVariableStatement& node) override
@ -137,26 +146,26 @@ namespace Nz
};
template<typename T>
constexpr ShaderAst::BasicType GetBasicType()
constexpr ShaderAst::PrimitiveType GetBasicType()
{
if constexpr (std::is_same_v<T, bool>)
return ShaderAst::BasicType::Boolean;
return ShaderAst::PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return(ShaderAst::BasicType::Float1);
return(ShaderAst::PrimitiveType::Float32);
else if constexpr (std::is_same_v<T, Int32>)
return(ShaderAst::BasicType::Int1);
return(ShaderAst::PrimitiveType::Int32);
else if constexpr (std::is_same_v<T, Vector2f>)
return(ShaderAst::BasicType::Float2);
return(ShaderAst::PrimitiveType::Float2);
else if constexpr (std::is_same_v<T, Vector3f>)
return(ShaderAst::BasicType::Float3);
return(ShaderAst::PrimitiveType::Float3);
else if constexpr (std::is_same_v<T, Vector4f>)
return(ShaderAst::BasicType::Float4);
return(ShaderAst::PrimitiveType::Float4);
else if constexpr (std::is_same_v<T, Vector2i32>)
return(ShaderAst::BasicType::Int2);
return(ShaderAst::PrimitiveType::Int2);
else if constexpr (std::is_same_v<T, Vector3i32>)
return(ShaderAst::BasicType::Int3);
return(ShaderAst::PrimitiveType::Int3);
else if constexpr (std::is_same_v<T, Vector4i32>)
return(ShaderAst::BasicType::Int4);
return(ShaderAst::PrimitiveType::Int4);
else
static_assert(AlwaysFalse<T>::value, "unhandled type");
}
@ -394,7 +403,7 @@ namespace Nz
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} });
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
@ -537,12 +546,12 @@ namespace Nz
return it.value();
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
}
@ -643,12 +652,12 @@ namespace Nz
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
@ -662,7 +671,7 @@ namespace Nz
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
std::vector<ShaderAst::ShaderExpressionType> parameterTypes;
std::vector<ShaderAst::ExpressionType> parameterTypes;
parameterTypes.reserve(functionNode.parameters.size());
for (const auto& parameter : functionNode.parameters)