Shader: Add support for custom functions calls (and better handle intrinsics)
This commit is contained in:
parent
8a6f0db034
commit
f6fd996bf1
|
|
@ -60,7 +60,6 @@ fn main(input: VertOut) -> FragOut
|
|||
let position = positionTexture.Sample(input.uv).xyz;
|
||||
|
||||
let distance = length(lightParameters.position - position);
|
||||
let attenuation = 1.0 / (lightParameters.constant + lightParameters.linear * distance + lightParameters.quadratic * (distance * distance));
|
||||
|
||||
let posToLight = (lightParameters.position - position) / distance;
|
||||
let lambert = dot(normal, posToLight);
|
||||
|
|
@ -68,6 +67,7 @@ fn main(input: VertOut) -> FragOut
|
|||
let curAngle = dot(lightParameters.direction, -posToLight);
|
||||
let innerMinusOuterAngle = lightParameters.innerAngle - lightParameters.outerAngle;
|
||||
|
||||
let attenuation = compute_attenuation(distance);
|
||||
attenuation = attenuation * max((curAngle - lightParameters.outerAngle) / innerMinusOuterAngle, 0.0);
|
||||
|
||||
let output: FragOut;
|
||||
|
|
@ -85,3 +85,8 @@ fn main(input: VertIn) -> VertOut
|
|||
|
||||
return output;
|
||||
}
|
||||
|
||||
fn compute_attenuation(distance: f32) -> f32
|
||||
{
|
||||
return 1.0 / (lightParameters.constant + lightParameters.linear * distance + lightParameters.quadratic * (distance * distance));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,8 @@ namespace Nz::ShaderAst
|
|||
virtual ExpressionPtr Clone(AccessMemberIndexExpression& node);
|
||||
virtual ExpressionPtr Clone(AssignExpression& node);
|
||||
virtual ExpressionPtr Clone(BinaryExpression& node);
|
||||
virtual ExpressionPtr Clone(CallFunctionExpression& node);
|
||||
virtual ExpressionPtr Clone(CallMethodExpression& node);
|
||||
virtual ExpressionPtr Clone(CastExpression& node);
|
||||
virtual ExpressionPtr Clone(ConditionalExpression& node);
|
||||
virtual ExpressionPtr Clone(ConstantExpression& node);
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ NAZARA_SHADERAST_EXPRESSION(AccessMemberIdentifierExpression)
|
|||
NAZARA_SHADERAST_EXPRESSION(AccessMemberIndexExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(AssignExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(BinaryExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(CallMethodExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(CastExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(ConditionalExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(ConstantExpression)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ namespace Nz::ShaderAst
|
|||
void Visit(AccessMemberIndexExpression& node) override;
|
||||
void Visit(AssignExpression& node) override;
|
||||
void Visit(BinaryExpression& node) override;
|
||||
void Visit(CallFunctionExpression& node) override;
|
||||
void Visit(CallMethodExpression& node) override;
|
||||
void Visit(CastExpression& node) override;
|
||||
void Visit(ConditionalExpression& node) override;
|
||||
void Visit(ConstantExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ namespace Nz::ShaderAst
|
|||
void Serialize(AccessMemberIndexExpression& node);
|
||||
void Serialize(AssignExpression& node);
|
||||
void Serialize(BinaryExpression& node);
|
||||
void Serialize(CallFunctionExpression& node);
|
||||
void Serialize(CallMethodExpression& node);
|
||||
void Serialize(CastExpression& node);
|
||||
void Serialize(ConditionalExpression& node);
|
||||
void Serialize(ConstantExpression& node);
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ namespace Nz::ShaderAst
|
|||
void Visit(AccessMemberIndexExpression& node) override;
|
||||
void Visit(AssignExpression& node) override;
|
||||
void Visit(BinaryExpression& node) override;
|
||||
void Visit(CallFunctionExpression& node) override;
|
||||
void Visit(CallMethodExpression& node) override;
|
||||
void Visit(CastExpression& node) override;
|
||||
void Visit(ConditionalExpression& node) override;
|
||||
void Visit(ConstantExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -102,6 +102,25 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr right;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API CallFunctionExpression : public Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
std::variant<std::string, std::size_t> targetFunction;
|
||||
std::vector<ExpressionPtr> parameters;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API CallMethodExpression : public Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
ExpressionPtr object;
|
||||
std::string methodName;
|
||||
std::vector<ExpressionPtr> parameters;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API CastExpression : public Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <Nazara/Prerequisites.hpp>
|
||||
#include <Nazara/Shader/Config.hpp>
|
||||
#include <Nazara/Shader/Ast/AstCloner.hpp>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ namespace Nz::ShaderAst
|
|||
public:
|
||||
struct Options;
|
||||
|
||||
inline SanitizeVisitor();
|
||||
SanitizeVisitor() = default;
|
||||
SanitizeVisitor(const SanitizeVisitor&) = delete;
|
||||
SanitizeVisitor(SanitizeVisitor&&) = delete;
|
||||
~SanitizeVisitor() = default;
|
||||
|
|
@ -47,6 +48,7 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr Clone(AccessMemberIdentifierExpression& node) override;
|
||||
ExpressionPtr Clone(AssignExpression& node) override;
|
||||
ExpressionPtr Clone(BinaryExpression& node) override;
|
||||
ExpressionPtr Clone(CallFunctionExpression& node) override;
|
||||
ExpressionPtr Clone(CastExpression& node) override;
|
||||
ExpressionPtr Clone(ConditionalExpression& node) override;
|
||||
ExpressionPtr Clone(ConstantExpression& node) override;
|
||||
|
|
@ -76,10 +78,11 @@ namespace Nz::ShaderAst
|
|||
void PushScope();
|
||||
void PopScope();
|
||||
|
||||
inline std::size_t RegisterFunction(std::string name);
|
||||
inline std::size_t RegisterOption(std::string name, ExpressionType type);
|
||||
inline std::size_t RegisterStruct(std::string name, StructDescription description);
|
||||
inline std::size_t RegisterVariable(std::string name, ExpressionType type);
|
||||
std::size_t RegisterFunction(DeclareFunctionStatement* funcDecl);
|
||||
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
|
||||
std::size_t RegisterOption(std::string name, ExpressionType type);
|
||||
std::size_t RegisterStruct(std::string name, StructDescription description);
|
||||
std::size_t RegisterVariable(std::string name, ExpressionType type);
|
||||
|
||||
std::size_t ResolveStruct(const ExpressionType& exprType);
|
||||
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
||||
|
|
@ -89,37 +92,33 @@ namespace Nz::ShaderAst
|
|||
|
||||
void SanitizeIdentifier(std::string& identifier);
|
||||
|
||||
struct Alias
|
||||
{
|
||||
std::variant<ExpressionType> value;
|
||||
};
|
||||
|
||||
struct Option
|
||||
{
|
||||
std::size_t optionIndex;
|
||||
};
|
||||
|
||||
struct Struct
|
||||
{
|
||||
std::size_t structIndex;
|
||||
};
|
||||
|
||||
struct Variable
|
||||
{
|
||||
std::size_t varIndex;
|
||||
};
|
||||
void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration);
|
||||
void Validate(IntrinsicExpression& node);
|
||||
|
||||
struct Identifier
|
||||
{
|
||||
enum class Type
|
||||
{
|
||||
Alias,
|
||||
Function,
|
||||
Intrinsic,
|
||||
Option,
|
||||
Struct,
|
||||
Variable
|
||||
};
|
||||
|
||||
std::string name;
|
||||
std::variant<Alias, Option, Struct, Variable> value;
|
||||
std::size_t index;
|
||||
Type type;
|
||||
};
|
||||
|
||||
std::size_t m_nextFuncIndex;
|
||||
std::unordered_map<std::string /*functionName*/, std::pair<const DeclareFunctionStatement*, std::size_t>> m_functionDeclarations;
|
||||
std::vector<Identifier> m_identifiersInScope;
|
||||
std::vector<DeclareFunctionStatement*> m_functions;
|
||||
std::vector<IntrinsicType> m_intrinsics;
|
||||
std::vector<ExpressionType> m_options;
|
||||
std::vector<StructDescription> m_structs;
|
||||
std::vector<ExpressionType> m_variables;
|
||||
std::vector<ExpressionType> m_variableTypes;
|
||||
std::vector<std::size_t> m_scopeSizes;
|
||||
|
||||
struct Context;
|
||||
|
|
|
|||
|
|
@ -7,11 +7,6 @@
|
|||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
inline SanitizeVisitor::SanitizeVisitor() :
|
||||
m_nextFuncIndex(0)
|
||||
{
|
||||
}
|
||||
|
||||
inline StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& statement, std::string* error)
|
||||
{
|
||||
return Sanitize(statement, {}, error);
|
||||
|
|
@ -26,56 +21,6 @@ namespace Nz::ShaderAst
|
|||
return &*it;
|
||||
}
|
||||
|
||||
inline std::size_t SanitizeVisitor::RegisterFunction(std::string name)
|
||||
{
|
||||
return m_nextFuncIndex++;
|
||||
}
|
||||
|
||||
inline std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type)
|
||||
{
|
||||
std::size_t optionIndex = m_options.size();
|
||||
m_options.emplace_back(std::move(type));
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
Option {
|
||||
optionIndex
|
||||
}
|
||||
});
|
||||
|
||||
return optionIndex;
|
||||
}
|
||||
|
||||
inline std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description)
|
||||
{
|
||||
std::size_t structIndex = m_structs.size();
|
||||
m_structs.emplace_back(std::move(description));
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
Struct {
|
||||
structIndex
|
||||
}
|
||||
});
|
||||
|
||||
return structIndex;
|
||||
}
|
||||
|
||||
inline std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type)
|
||||
{
|
||||
std::size_t varIndex = m_variables.size();
|
||||
m_variables.emplace_back(std::move(type));
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
Variable {
|
||||
varIndex
|
||||
}
|
||||
});
|
||||
|
||||
return varIndex;
|
||||
}
|
||||
|
||||
inline StatementPtr Sanitize(const StatementPtr& ast, std::string* error)
|
||||
{
|
||||
SanitizeVisitor sanitizer;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
|
|
@ -60,8 +61,9 @@ namespace Nz
|
|||
template<typename T> void Append(const T& param);
|
||||
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
|
||||
void AppendCommentSection(const std::string& section);
|
||||
void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false);
|
||||
void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers);
|
||||
void AppendHeader();
|
||||
void AppendHeader(const std::vector<ShaderAst::DeclareFunctionStatement*>& forwardFunctionDeclarations);
|
||||
void AppendLine(const std::string& txt = {});
|
||||
template<typename... Args> void AppendLine(Args&&... params);
|
||||
void AppendStatementList(std::vector<ShaderAst::StatementPtr>& statements);
|
||||
|
|
@ -72,6 +74,7 @@ namespace Nz
|
|||
void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node);
|
||||
void HandleInOut();
|
||||
|
||||
void RegisterFunction(std::size_t funcIndex, std::string funcName);
|
||||
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc);
|
||||
void RegisterVariable(std::size_t varIndex, std::string varName);
|
||||
|
||||
|
|
@ -80,6 +83,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::AccessMemberIndexExpression& node) override;
|
||||
void Visit(ShaderAst::AssignExpression& node) override;
|
||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||
void Visit(ShaderAst::CastExpression& node) override;
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override;
|
||||
void Visit(ShaderAst::ConstantExpression& node) override;
|
||||
|
|
|
|||
|
|
@ -38,6 +38,11 @@ namespace Nz::ShaderBuilder
|
|||
inline std::unique_ptr<ShaderAst::BranchStatement> operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement = nullptr) const;
|
||||
};
|
||||
|
||||
struct CallFunction
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::CallFunctionExpression> operator()(std::string functionName, std::vector<ShaderAst::ExpressionPtr> parameters) const;
|
||||
};
|
||||
|
||||
struct Cast
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const;
|
||||
|
|
@ -133,6 +138,7 @@ namespace Nz::ShaderBuilder
|
|||
constexpr Impl::Assign Assign;
|
||||
constexpr Impl::Binary Binary;
|
||||
constexpr Impl::Branch Branch;
|
||||
constexpr Impl::CallFunction CallFunction;
|
||||
constexpr Impl::Cast Cast;
|
||||
constexpr Impl::ConditionalExpression ConditionalExpression;
|
||||
constexpr Impl::ConditionalStatement ConditionalStatement;
|
||||
|
|
|
|||
|
|
@ -58,6 +58,15 @@ namespace Nz::ShaderBuilder
|
|||
return branchNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::CallFunctionExpression> Impl::CallFunction::operator()(std::string functionName, std::vector<ShaderAst::ExpressionPtr> parameters) const
|
||||
{
|
||||
auto callFunctionExpression = std::make_unique<ShaderAst::CallFunctionExpression>();
|
||||
callFunctionExpression->targetFunction = std::move(functionName);
|
||||
callFunctionExpression->parameters = std::move(parameters);
|
||||
|
||||
return callFunctionExpression;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const
|
||||
{
|
||||
auto castNode = std::make_unique<ShaderAst::CastExpression>();
|
||||
|
|
@ -138,7 +147,7 @@ namespace Nz::ShaderBuilder
|
|||
return declareFunctionNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareOptionStatement> Nz::ShaderBuilder::Impl::DeclareOption::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
|
||||
inline std::unique_ptr<ShaderAst::DeclareOptionStatement> Impl::DeclareOption::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
|
||||
{
|
||||
auto declareOptionNode = std::make_unique<ShaderAst::DeclareOptionStatement>();
|
||||
declareOptionNode->optName = std::move(name);
|
||||
|
|
@ -156,7 +165,7 @@ namespace Nz::ShaderBuilder
|
|||
return declareStructNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
|
||||
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
|
||||
{
|
||||
auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>();
|
||||
declareVariableNode->varName = std::move(name);
|
||||
|
|
@ -165,7 +174,7 @@ namespace Nz::ShaderBuilder
|
|||
return declareVariableNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
|
||||
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
|
||||
{
|
||||
auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>();
|
||||
declareVariableNode->varName = std::move(name);
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <Nazara/Shader/ShaderLangLexer.hpp>
|
||||
#include <Nazara/Shader/Ast/Nodes.hpp>
|
||||
#include <filesystem>
|
||||
#include <optional>
|
||||
|
||||
namespace Nz::ShaderLang
|
||||
{
|
||||
|
|
@ -69,7 +70,7 @@ namespace Nz::ShaderLang
|
|||
// Flow control
|
||||
const Token& Advance();
|
||||
void Consume(std::size_t count = 1);
|
||||
ShaderAst::ExpressionType DecodeType(const std::string& identifier);
|
||||
std::optional<ShaderAst::ExpressionType> DecodeType(const std::string& identifier);
|
||||
void EnterScope();
|
||||
const Token& Expect(const Token& token, TokenType type);
|
||||
const Token& ExpectNot(const Token& token, TokenType type);
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::AssignExpression& node) override;
|
||||
void Visit(ShaderAst::BinaryExpression& node) override;
|
||||
void Visit(ShaderAst::BranchStatement& node) override;
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||
void Visit(ShaderAst::CastExpression& node) override;
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override;
|
||||
void Visit(ShaderAst::ConditionalStatement& node) override;
|
||||
|
|
@ -92,7 +93,6 @@ namespace Nz
|
|||
ShaderStageType stageType;
|
||||
std::optional<InputStruct> inputStruct;
|
||||
std::optional<UInt32> outputStructTypeId;
|
||||
std::size_t funcIndex;
|
||||
std::vector<Input> inputs;
|
||||
std::vector<Output> outputs;
|
||||
};
|
||||
|
|
@ -101,6 +101,11 @@ namespace Nz
|
|||
{
|
||||
std::optional<EntryPoint> entryPointData;
|
||||
|
||||
struct FuncCall
|
||||
{
|
||||
std::size_t firstVarIndex;
|
||||
};
|
||||
|
||||
struct Parameter
|
||||
{
|
||||
UInt32 pointerTypeId;
|
||||
|
|
@ -113,7 +118,9 @@ namespace Nz
|
|||
UInt32 varId;
|
||||
};
|
||||
|
||||
std::size_t funcIndex;
|
||||
std::string name;
|
||||
std::vector<FuncCall> funcCalls;
|
||||
std::vector<Parameter> parameters;
|
||||
std::vector<Variable> variables;
|
||||
std::unordered_map<std::size_t, std::size_t> varIndexToVarId;
|
||||
|
|
@ -138,6 +145,7 @@ namespace Nz
|
|||
inline void RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass);
|
||||
|
||||
std::size_t m_extVarIndex;
|
||||
std::size_t m_funcCallIndex;
|
||||
std::size_t m_funcIndex;
|
||||
std::vector<std::size_t> m_scopeSizes;
|
||||
std::vector<FuncData>& m_funcData;
|
||||
|
|
|
|||
|
|
@ -202,6 +202,36 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(CallFunctionExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<CallFunctionExpression>();
|
||||
clone->targetFunction = node.targetFunction;
|
||||
|
||||
clone->parameters.reserve(node.parameters.size());
|
||||
for (auto& parameter : node.parameters)
|
||||
clone->parameters.push_back(CloneExpression(parameter));
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(CallMethodExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<CallMethodExpression>();
|
||||
clone->methodName = node.methodName;
|
||||
|
||||
clone->object = CloneExpression(node.object);
|
||||
|
||||
clone->parameters.reserve(node.parameters.size());
|
||||
for (auto& parameter : node.parameters)
|
||||
clone->parameters.push_back(CloneExpression(parameter));
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(CastExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<CastExpression>();
|
||||
|
|
|
|||
|
|
@ -29,6 +29,20 @@ namespace Nz::ShaderAst
|
|||
node.right->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(CallFunctionExpression& node)
|
||||
{
|
||||
for (auto& param : node.parameters)
|
||||
param->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(CallMethodExpression& node)
|
||||
{
|
||||
node.object->Visit(*this);
|
||||
|
||||
for (auto& param : node.parameters)
|
||||
param->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(CastExpression& node)
|
||||
{
|
||||
for (auto& expr : node.expressions)
|
||||
|
|
|
|||
|
|
@ -65,6 +65,45 @@ namespace Nz::ShaderAst
|
|||
Node(node.right);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(CallFunctionExpression& node)
|
||||
{
|
||||
UInt32 typeIndex;
|
||||
if (IsWriting())
|
||||
typeIndex = UInt32(node.targetFunction.index());
|
||||
|
||||
Value(typeIndex);
|
||||
|
||||
// Waiting for template lambda in C++20
|
||||
auto SerializeValue = [&](auto dummyType)
|
||||
{
|
||||
using T = std::decay_t<decltype(dummyType)>;
|
||||
|
||||
auto& value = (IsWriting()) ? std::get<T>(node.targetFunction) : node.targetFunction.emplace<T>();
|
||||
Value(value);
|
||||
};
|
||||
|
||||
static_assert(std::variant_size_v<decltype(node.targetFunction)> == 2);
|
||||
switch (typeIndex)
|
||||
{
|
||||
case 0: SerializeValue(std::string()); break;
|
||||
case 1: SerializeValue(std::size_t()); break;
|
||||
}
|
||||
|
||||
Container(node.parameters);
|
||||
for (auto& param : node.parameters)
|
||||
Node(param);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(CallMethodExpression& node)
|
||||
{
|
||||
Node(node.object);
|
||||
Value(node.methodName);
|
||||
|
||||
Container(node.parameters);
|
||||
for (auto& param : node.parameters)
|
||||
Node(param);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(CastExpression& node)
|
||||
{
|
||||
Type(node.targetType);
|
||||
|
|
|
|||
|
|
@ -33,6 +33,16 @@ namespace Nz::ShaderAst
|
|||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(CallFunctionExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(CallMethodExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(CastExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <Nazara/Core/CallOnExit.hpp>
|
||||
#include <Nazara/Core/StackArray.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
|
@ -47,6 +48,28 @@ namespace Nz::ShaderAst
|
|||
|
||||
PushScope(); //< Global scope
|
||||
{
|
||||
RegisterIntrinsic("cross", IntrinsicType::CrossProduct);
|
||||
RegisterIntrinsic("dot", IntrinsicType::DotProduct);
|
||||
RegisterIntrinsic("max", IntrinsicType::Max);
|
||||
RegisterIntrinsic("min", IntrinsicType::Min);
|
||||
RegisterIntrinsic("length", IntrinsicType::Length);
|
||||
|
||||
// Collect function name and their types
|
||||
if (nodePtr->GetType() == NodeType::MultiStatement)
|
||||
{
|
||||
std::size_t functionIndex = 0;
|
||||
|
||||
const MultiStatement& multiStatement = static_cast<const MultiStatement&>(*nodePtr);
|
||||
for (const auto& statementPtr : multiStatement.statements)
|
||||
{
|
||||
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
|
||||
{
|
||||
const DeclareFunctionStatement& funcDeclaration = static_cast<const DeclareFunctionStatement&>(*statementPtr);
|
||||
m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
clone = AstCloner::Clone(nodePtr);
|
||||
|
|
@ -355,6 +378,71 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node)
|
||||
{
|
||||
constexpr std::size_t NoFunction = std::numeric_limits<std::size_t>::max();
|
||||
|
||||
auto clone = std::make_unique<CallFunctionExpression>();
|
||||
|
||||
clone->parameters.reserve(node.parameters.size());
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
clone->parameters.push_back(CloneExpression(node.parameters[i]));
|
||||
|
||||
const DeclareFunctionStatement* referenceFunctionDeclaration;
|
||||
if (std::holds_alternative<std::string>(node.targetFunction))
|
||||
{
|
||||
const std::string& functionName = std::get<std::string>(node.targetFunction);
|
||||
|
||||
const Identifier* identifier = FindIdentifier(functionName);
|
||||
if (identifier)
|
||||
{
|
||||
if (identifier->type == Identifier::Type::Intrinsic)
|
||||
{
|
||||
// Intrinsic function call
|
||||
std::vector<ExpressionPtr> parameters;
|
||||
parameters.reserve(node.parameters.size());
|
||||
|
||||
for (const auto& param : node.parameters)
|
||||
parameters.push_back(CloneExpression(param));
|
||||
|
||||
auto intrinsic = ShaderBuilder::Intrinsic(m_intrinsics[identifier->index], std::move(parameters));
|
||||
Validate(*intrinsic);
|
||||
|
||||
return intrinsic;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Regular function call
|
||||
if (identifier->type != Identifier::Type::Function)
|
||||
throw AstError{ "function expected" };
|
||||
|
||||
clone->targetFunction = identifier->index;
|
||||
referenceFunctionDeclaration = m_functions[identifier->index];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Identifier not found, maybe the function is declared later
|
||||
auto it = m_functionDeclarations.find(functionName);
|
||||
if (it == m_functionDeclarations.end())
|
||||
throw AstError{ "function " + functionName + " does not exist" };
|
||||
|
||||
clone->targetFunction = it->second.second;
|
||||
|
||||
referenceFunctionDeclaration = it->second.first;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::size_t funcIndex = std::get<std::size_t>(node.targetFunction);
|
||||
referenceFunctionDeclaration = m_functions[funcIndex];
|
||||
}
|
||||
|
||||
Validate(*clone, referenceFunctionDeclaration);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(CastExpression& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
|
||||
|
|
@ -426,15 +514,13 @@ namespace Nz::ShaderAst
|
|||
if (!identifier)
|
||||
throw AstError{ "unknown identifier " + node.identifier };
|
||||
|
||||
if (!std::holds_alternative<Variable>(identifier->value))
|
||||
if (identifier->type != Identifier::Type::Variable)
|
||||
throw AstError{ "expected variable identifier" };
|
||||
|
||||
const Variable& variable = std::get<Variable>(identifier->value);
|
||||
|
||||
// Replace IdentifierExpression by VariableExpression
|
||||
auto varExpr = std::make_unique<VariableExpression>();
|
||||
varExpr->cachedExpressionType = m_variables[variable.varIndex];
|
||||
varExpr->variableId = variable.varIndex;
|
||||
varExpr->cachedExpressionType = m_variableTypes[identifier->index];
|
||||
varExpr->variableId = identifier->index;
|
||||
|
||||
return varExpr;
|
||||
}
|
||||
|
|
@ -442,110 +528,7 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<IntrinsicExpression>(AstCloner::Clone(node));
|
||||
|
||||
// Parameter validation
|
||||
switch (clone->intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
case IntrinsicType::DotProduct:
|
||||
case IntrinsicType::Max:
|
||||
case IntrinsicType::Min:
|
||||
{
|
||||
if (clone->parameters.size() != 2)
|
||||
throw AstError { "Expected two parameters" };
|
||||
|
||||
for (auto& param : clone->parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
||||
|
||||
for (std::size_t i = 1; i < clone->parameters.size(); ++i)
|
||||
{
|
||||
if (type != GetExpressionType(*clone->parameters[i]))
|
||||
throw AstError{ "All type must match" };
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::Length:
|
||||
{
|
||||
if (clone->parameters.size() != 1)
|
||||
throw AstError{ "Expected only one parameters" };
|
||||
|
||||
for (auto& param : clone->parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
||||
if (!IsVectorType(type))
|
||||
throw AstError{ "Expected a vector" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::SampleTexture:
|
||||
{
|
||||
if (clone->parameters.size() != 2)
|
||||
throw AstError{ "Expected two parameters" };
|
||||
|
||||
for (auto& param : clone->parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
if (!IsSamplerType(GetExpressionType(*clone->parameters[0])))
|
||||
throw AstError{ "First parameter must be a sampler" };
|
||||
|
||||
if (!IsVectorType(GetExpressionType(*clone->parameters[1])))
|
||||
throw AstError{ "Second parameter must be a vector" };
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Return type attribution
|
||||
switch (clone->intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
{
|
||||
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
||||
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
|
||||
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
|
||||
|
||||
clone->cachedExpressionType = type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::DotProduct:
|
||||
case IntrinsicType::Length:
|
||||
{
|
||||
ExpressionType type = GetExpressionType(*clone->parameters.front());
|
||||
if (!IsVectorType(type))
|
||||
throw AstError{ "DotProduct expects vector types" };
|
||||
|
||||
clone->cachedExpressionType = std::get<VectorType>(type).type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::Max:
|
||||
case IntrinsicType::Min:
|
||||
{
|
||||
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
||||
if (!IsPrimitiveType(type) && !IsVectorType(type))
|
||||
throw AstError{ "max and min only work with primitive and vector types" };
|
||||
|
||||
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) == PrimitiveType::Boolean) ||
|
||||
(IsVectorType(type) && std::get<VectorType>(type).type == PrimitiveType::Boolean))
|
||||
throw AstError{ "max and min do not work with booleans" };
|
||||
|
||||
clone->cachedExpressionType = type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::SampleTexture:
|
||||
{
|
||||
clone->cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*clone->parameters.front())).sampledType };
|
||||
break;
|
||||
}
|
||||
}
|
||||
Validate(*clone);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
|
@ -563,10 +546,10 @@ namespace Nz::ShaderAst
|
|||
if (!identifier)
|
||||
throw AstError{ "unknown option " + node.optionName };
|
||||
|
||||
if (!std::holds_alternative<Option>(identifier->value))
|
||||
if (identifier->type != Identifier::Type::Option)
|
||||
throw AstError{ "expected option identifier" };
|
||||
|
||||
condExpr->optionIndex = std::get<Option>(identifier->value).optionIndex;
|
||||
condExpr->optionIndex = identifier->index;
|
||||
|
||||
const ExpressionType& leftExprType = GetExpressionType(*condExpr->truePath);
|
||||
if (leftExprType != GetExpressionType(*condExpr->falsePath))
|
||||
|
|
@ -754,7 +737,6 @@ namespace Nz::ShaderAst
|
|||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->entryStage = node.entryStage;
|
||||
clone->name = node.name;
|
||||
clone->funcIndex = m_nextFuncIndex++;
|
||||
clone->optionName = node.optionName;
|
||||
clone->parameters = node.parameters;
|
||||
clone->returnType = ResolveType(node.returnType);
|
||||
|
|
@ -785,14 +767,16 @@ namespace Nz::ShaderAst
|
|||
if (!identifier)
|
||||
throw AstError{ "unknown option " + node.optionName };
|
||||
|
||||
if (!std::holds_alternative<Option>(identifier->value))
|
||||
if (identifier->type != Identifier::Type::Option)
|
||||
throw AstError{ "expected option identifier" };
|
||||
|
||||
std::size_t optionIndex = std::get<Option>(identifier->value).optionIndex;
|
||||
std::size_t optionIndex = identifier->index;
|
||||
|
||||
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
|
||||
}
|
||||
|
||||
clone->funcIndex = RegisterFunction(clone.get());
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
|
@ -905,6 +889,105 @@ namespace Nz::ShaderAst
|
|||
m_scopeSizes.pop_back();
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterFunction(DeclareFunctionStatement* funcDecl)
|
||||
{
|
||||
if (auto* identifier = FindIdentifier(funcDecl->name))
|
||||
{
|
||||
bool duplicate = true;
|
||||
|
||||
// Functions cannot be declared twice, except for entry ones if their stages are different
|
||||
if (funcDecl->entryStage && identifier->type == Identifier::Type::Function)
|
||||
{
|
||||
auto& otherFunction = m_functions[identifier->index];
|
||||
if (funcDecl->entryStage != otherFunction->entryStage)
|
||||
duplicate = false;
|
||||
}
|
||||
|
||||
if (duplicate)
|
||||
throw AstError{ funcDecl->name + " is already used" };
|
||||
}
|
||||
|
||||
std::size_t functionIndex = m_functions.size();
|
||||
m_functions.push_back(funcDecl);
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
funcDecl->name,
|
||||
functionIndex,
|
||||
Identifier::Type::Function
|
||||
});
|
||||
|
||||
return functionIndex;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
|
||||
{
|
||||
if (FindIdentifier(name))
|
||||
throw AstError{ name + " is already used" };
|
||||
|
||||
std::size_t intrinsicIndex = m_intrinsics.size();
|
||||
m_intrinsics.push_back(type);
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
intrinsicIndex,
|
||||
Identifier::Type::Intrinsic
|
||||
});
|
||||
|
||||
return intrinsicIndex;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type)
|
||||
{
|
||||
if (FindIdentifier(name))
|
||||
throw AstError{ name + " is already used" };
|
||||
|
||||
std::size_t optionIndex = m_options.size();
|
||||
m_options.emplace_back(std::move(type));
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
optionIndex,
|
||||
Identifier::Type::Option
|
||||
});
|
||||
|
||||
return optionIndex;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description)
|
||||
{
|
||||
if (FindIdentifier(name))
|
||||
throw AstError{ name + " is already used" };
|
||||
|
||||
std::size_t structIndex = m_structs.size();
|
||||
m_structs.emplace_back(std::move(description));
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
structIndex,
|
||||
Identifier::Type::Struct
|
||||
});
|
||||
|
||||
return structIndex;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type)
|
||||
{
|
||||
// Allow variable shadowing
|
||||
if (auto* identifier = FindIdentifier(name); identifier && identifier->type != Identifier::Type::Variable)
|
||||
throw AstError{ name + " is already used" };
|
||||
|
||||
std::size_t varIndex = m_variableTypes.size();
|
||||
m_variableTypes.emplace_back(std::move(type));
|
||||
|
||||
m_identifiersInScope.push_back({
|
||||
std::move(name),
|
||||
varIndex,
|
||||
Identifier::Type::Variable
|
||||
});
|
||||
|
||||
return varIndex;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> std::size_t
|
||||
|
|
@ -932,10 +1015,10 @@ namespace Nz::ShaderAst
|
|||
if (!identifier)
|
||||
throw AstError{ "unknown identifier " + identifierType.name };
|
||||
|
||||
if (!std::holds_alternative<Struct>(identifier->value))
|
||||
if (identifier->type != Identifier::Type::Struct)
|
||||
throw AstError{ identifierType.name + " is not a struct" };
|
||||
|
||||
return std::get<Struct>(identifier->value).structIndex;
|
||||
return identifier->index;
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType)
|
||||
|
|
@ -977,10 +1060,10 @@ namespace Nz::ShaderAst
|
|||
if (!identifier)
|
||||
throw AstError{ "unknown identifier " + arg.name };
|
||||
|
||||
if (!std::holds_alternative<Struct>(identifier->value))
|
||||
if (identifier->type != Identifier::Type::Struct)
|
||||
throw AstError{ "expected type identifier" };
|
||||
|
||||
return StructType{ std::get<Struct>(identifier->value).structIndex };
|
||||
return StructType{ identifier->index };
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, UniformType>)
|
||||
{
|
||||
|
|
@ -1010,6 +1093,130 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
|
||||
{
|
||||
if (referenceDeclaration->entryStage)
|
||||
throw AstError{ referenceDeclaration->name + " is an entry function which cannot be called by the program" };
|
||||
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (GetExpressionType(*node.parameters[i]) != referenceDeclaration->parameters[i].type)
|
||||
throw AstError{ "function " + referenceDeclaration->name + " parameter " + std::to_string(i) + " type mismatch" };
|
||||
}
|
||||
|
||||
if (node.parameters.size() != referenceDeclaration->parameters.size())
|
||||
throw AstError{ "function " + referenceDeclaration->name + " expected " + std::to_string(referenceDeclaration->parameters.size()) + " parameters, got " + std::to_string(node.parameters.size()) };
|
||||
|
||||
node.cachedExpressionType = referenceDeclaration->returnType;
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(IntrinsicExpression& node)
|
||||
{
|
||||
// Parameter validation
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
case IntrinsicType::DotProduct:
|
||||
case IntrinsicType::Max:
|
||||
case IntrinsicType::Min:
|
||||
{
|
||||
if (node.parameters.size() != 2)
|
||||
throw AstError { "Expected two parameters" };
|
||||
|
||||
for (auto& param : node.parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
||||
|
||||
for (std::size_t i = 1; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (type != GetExpressionType(*node.parameters[i]))
|
||||
throw AstError{ "All type must match" };
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::Length:
|
||||
{
|
||||
if (node.parameters.size() != 1)
|
||||
throw AstError{ "Expected only one parameters" };
|
||||
|
||||
for (auto& param : node.parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
||||
if (!IsVectorType(type))
|
||||
throw AstError{ "Expected a vector" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::SampleTexture:
|
||||
{
|
||||
if (node.parameters.size() != 2)
|
||||
throw AstError{ "Expected two parameters" };
|
||||
|
||||
for (auto& param : node.parameters)
|
||||
MandatoryExpr(param);
|
||||
|
||||
if (!IsSamplerType(GetExpressionType(*node.parameters[0])))
|
||||
throw AstError{ "First parameter must be a sampler" };
|
||||
|
||||
if (!IsVectorType(GetExpressionType(*node.parameters[1])))
|
||||
throw AstError{ "Second parameter must be a vector" };
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Return type attribution
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
{
|
||||
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
||||
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
|
||||
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
|
||||
|
||||
node.cachedExpressionType = type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::DotProduct:
|
||||
case IntrinsicType::Length:
|
||||
{
|
||||
ExpressionType type = GetExpressionType(*node.parameters.front());
|
||||
if (!IsVectorType(type))
|
||||
throw AstError{ "DotProduct expects vector types" };
|
||||
|
||||
node.cachedExpressionType = std::get<VectorType>(type).type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::Max:
|
||||
case IntrinsicType::Min:
|
||||
{
|
||||
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
||||
if (!IsPrimitiveType(type) && !IsVectorType(type))
|
||||
throw AstError{ "max and min only work with primitive and vector types" };
|
||||
|
||||
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) == PrimitiveType::Boolean) ||
|
||||
(IsVectorType(type) && std::get<VectorType>(type).type == PrimitiveType::Boolean))
|
||||
throw AstError{ "max and min do not work with booleans" };
|
||||
|
||||
node.cachedExpressionType = type;
|
||||
break;
|
||||
}
|
||||
|
||||
case IntrinsicType::SampleTexture:
|
||||
{
|
||||
node.cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*node.parameters.front())).sampledType };
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
|
||||
{
|
||||
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ namespace Nz
|
|||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
// Dismiss function if it's an entry point of another type than the one selected
|
||||
if (selectedStage)
|
||||
if (node.entryStage)
|
||||
{
|
||||
if (node.entryStage)
|
||||
if (selectedStage)
|
||||
{
|
||||
ShaderStageType stage = *node.entryStage;
|
||||
if (stage != *selectedStage)
|
||||
|
|
@ -56,15 +56,22 @@ namespace Nz
|
|||
assert(!entryPoint);
|
||||
entryPoint = &node;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(!entryPoint);
|
||||
entryPoint = &node;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(!entryPoint);
|
||||
entryPoint = &node;
|
||||
}
|
||||
forwardFunctionDeclarations.push_back(&node);
|
||||
|
||||
assert(node.funcIndex);
|
||||
functionNames[node.funcIndex.value()] = node.name;
|
||||
}
|
||||
|
||||
std::optional<ShaderStageType> selectedStage;
|
||||
std::unordered_map<std::size_t, std::string> functionNames;
|
||||
std::vector<ShaderAst::DeclareFunctionStatement*> forwardFunctionDeclarations;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
|
||||
UInt64 enabledOptions = 0;
|
||||
};
|
||||
|
|
@ -94,6 +101,7 @@ namespace Nz
|
|||
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
|
||||
std::stringstream stream;
|
||||
std::unordered_map<std::size_t, ShaderAst::StructDescription> structs;
|
||||
std::unordered_map<std::size_t, std::string> functionNames;
|
||||
std::unordered_map<std::size_t, std::string> variableNames;
|
||||
std::vector<InOutField> inputFields;
|
||||
std::vector<InOutField> outputFields;
|
||||
|
|
@ -143,8 +151,9 @@ namespace Nz
|
|||
throw std::runtime_error("missing entry point");
|
||||
|
||||
state.entryFunc = previsitor.entryPoint;
|
||||
state.functionNames = std::move(previsitor.functionNames);
|
||||
|
||||
AppendHeader();
|
||||
AppendHeader(previsitor.forwardFunctionDeclarations);
|
||||
|
||||
sanitizedAst->Visit(*this);
|
||||
|
||||
|
|
@ -300,6 +309,23 @@ namespace Nz
|
|||
AppendLine();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward)
|
||||
{
|
||||
Append(node.returnType, " ", node.name, "(");
|
||||
|
||||
bool first = true;
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
if (!first)
|
||||
Append(", ");
|
||||
|
||||
first = false;
|
||||
|
||||
Append(parameter.type, " ", parameter.name);
|
||||
}
|
||||
AppendLine((forward) ? ");" : ")");
|
||||
}
|
||||
|
||||
void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
|
||||
{
|
||||
const auto& structDesc = Retrieve(m_currentState->structs, structIndex);
|
||||
|
|
@ -315,6 +341,98 @@ namespace Nz
|
|||
AppendField(std::get<ShaderAst::StructType>(member.type).structIndex, memberIndices + 1, remainingMembers - 1);
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::AppendHeader(const std::vector<ShaderAst::DeclareFunctionStatement*>& forwardFunctionDeclarations)
|
||||
{
|
||||
unsigned int glslVersion;
|
||||
if (m_environment.glES)
|
||||
{
|
||||
if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 2)
|
||||
glslVersion = 320;
|
||||
else if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1)
|
||||
glslVersion = 310;
|
||||
else if (m_environment.glMajorVersion >= 3)
|
||||
glslVersion = 300;
|
||||
else if (m_environment.glMajorVersion >= 2)
|
||||
glslVersion = 100;
|
||||
else
|
||||
throw std::runtime_error("This version of OpenGL ES does not support shaders");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 3)
|
||||
glslVersion = m_environment.glMajorVersion * 100 + m_environment.glMinorVersion * 10;
|
||||
else if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 2)
|
||||
glslVersion = 150;
|
||||
else if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1)
|
||||
glslVersion = 140;
|
||||
else if (m_environment.glMajorVersion >= 3)
|
||||
glslVersion = 130;
|
||||
else if (m_environment.glMajorVersion >= 2 && m_environment.glMinorVersion >= 1)
|
||||
glslVersion = 120;
|
||||
else if (m_environment.glMajorVersion >= 2)
|
||||
glslVersion = 110;
|
||||
else
|
||||
throw std::runtime_error("This version of OpenGL does not support shaders");
|
||||
}
|
||||
|
||||
// Header
|
||||
Append("#version ");
|
||||
Append(glslVersion);
|
||||
if (m_environment.glES)
|
||||
Append(" es");
|
||||
|
||||
AppendLine();
|
||||
AppendLine();
|
||||
|
||||
// Extensions
|
||||
|
||||
std::vector<std::string> requiredExtensions;
|
||||
|
||||
if (!m_environment.glES && m_environment.extCallback)
|
||||
{
|
||||
// GL_ARB_shading_language_420pack (required for layout(binding = X))
|
||||
if (glslVersion < 420)
|
||||
{
|
||||
if (m_environment.extCallback("GL_ARB_shading_language_420pack"))
|
||||
requiredExtensions.emplace_back("GL_ARB_shading_language_420pack");
|
||||
}
|
||||
|
||||
// GL_ARB_separate_shader_objects (required for layout(location = X))
|
||||
if (glslVersion < 410)
|
||||
{
|
||||
if (m_environment.extCallback("GL_ARB_separate_shader_objects"))
|
||||
requiredExtensions.emplace_back("GL_ARB_separate_shader_objects");
|
||||
}
|
||||
}
|
||||
|
||||
if (!requiredExtensions.empty())
|
||||
{
|
||||
for (const std::string& ext : requiredExtensions)
|
||||
AppendLine("#extension " + ext + " : require");
|
||||
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
if (m_environment.glES)
|
||||
{
|
||||
AppendLine("#if GL_FRAGMENT_PRECISION_HIGH");
|
||||
AppendLine("precision highp float;");
|
||||
AppendLine("#else");
|
||||
AppendLine("precision mediump float;");
|
||||
AppendLine("#endif");
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
if (!forwardFunctionDeclarations.empty())
|
||||
{
|
||||
AppendCommentSection("function declarations");
|
||||
for (const ShaderAst::DeclareFunctionStatement* node : forwardFunctionDeclarations)
|
||||
AppendFunctionDeclaration(*node, true);
|
||||
|
||||
AppendLine();
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::AppendLine(const std::string& txt)
|
||||
{
|
||||
|
|
@ -481,6 +599,12 @@ namespace Nz
|
|||
}
|
||||
}
|
||||
|
||||
void GlslWriter::RegisterFunction(std::size_t funcIndex, std::string funcName)
|
||||
{
|
||||
assert(m_currentState->functionNames.find(funcIndex) == m_currentState->functionNames.end());
|
||||
m_currentState->functionNames.emplace(funcIndex, std::move(funcName));
|
||||
}
|
||||
|
||||
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc)
|
||||
{
|
||||
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
|
||||
|
|
@ -581,6 +705,22 @@ namespace Nz
|
|||
Visit(node.right, true);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
assert(std::holds_alternative<std::size_t>(node.targetFunction));
|
||||
const std::string& targetName = Retrieve(m_currentState->functionNames, std::get<std::size_t>(node.targetFunction));
|
||||
|
||||
Append(targetName, "(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
|
||||
node.parameters[i]->Visit(*this);
|
||||
}
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
Append(node.targetType);
|
||||
|
|
@ -720,25 +860,14 @@ namespace Nz
|
|||
|
||||
std::optional<std::size_t> varIndexOpt = node.varIndex;
|
||||
|
||||
Append(node.returnType);
|
||||
Append(" ");
|
||||
Append(node.name);
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
|
||||
Append(node.parameters[i].type);
|
||||
Append(" ");
|
||||
Append(node.parameters[i].name);
|
||||
|
||||
assert(varIndexOpt);
|
||||
std::size_t& varIndex = *varIndexOpt;
|
||||
RegisterVariable(varIndex++, node.parameters[i].name);
|
||||
RegisterVariable(varIndex++, parameter.name);
|
||||
}
|
||||
Append(")\n");
|
||||
|
||||
AppendFunctionDeclaration(node);
|
||||
EnterScope();
|
||||
{
|
||||
AppendStatementList(node.statements);
|
||||
|
|
@ -989,88 +1118,4 @@ namespace Nz
|
|||
|
||||
return false;
|
||||
}
|
||||
|
||||
void GlslWriter::AppendHeader()
|
||||
{
|
||||
unsigned int glslVersion;
|
||||
if (m_environment.glES)
|
||||
{
|
||||
if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 2)
|
||||
glslVersion = 320;
|
||||
else if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1)
|
||||
glslVersion = 310;
|
||||
else if (m_environment.glMajorVersion >= 3)
|
||||
glslVersion = 300;
|
||||
else if (m_environment.glMajorVersion >= 2)
|
||||
glslVersion = 100;
|
||||
else
|
||||
throw std::runtime_error("This version of OpenGL ES does not support shaders");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 3)
|
||||
glslVersion = m_environment.glMajorVersion * 100 + m_environment.glMinorVersion * 10;
|
||||
else if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 2)
|
||||
glslVersion = 150;
|
||||
else if (m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1)
|
||||
glslVersion = 140;
|
||||
else if (m_environment.glMajorVersion >= 3)
|
||||
glslVersion = 130;
|
||||
else if (m_environment.glMajorVersion >= 2 && m_environment.glMinorVersion >= 1)
|
||||
glslVersion = 120;
|
||||
else if (m_environment.glMajorVersion >= 2)
|
||||
glslVersion = 110;
|
||||
else
|
||||
throw std::runtime_error("This version of OpenGL does not support shaders");
|
||||
}
|
||||
|
||||
// Header
|
||||
Append("#version ");
|
||||
Append(glslVersion);
|
||||
if (m_environment.glES)
|
||||
Append(" es");
|
||||
|
||||
AppendLine();
|
||||
AppendLine();
|
||||
|
||||
// Extensions
|
||||
|
||||
std::vector<std::string> requiredExtensions;
|
||||
|
||||
if (!m_environment.glES && m_environment.extCallback)
|
||||
{
|
||||
// GL_ARB_shading_language_420pack (required for layout(binding = X))
|
||||
if (glslVersion < 420)
|
||||
{
|
||||
if (m_environment.extCallback("GL_ARB_shading_language_420pack"))
|
||||
requiredExtensions.emplace_back("GL_ARB_shading_language_420pack");
|
||||
}
|
||||
|
||||
// GL_ARB_separate_shader_objects (required for layout(location = X))
|
||||
if (glslVersion < 410)
|
||||
{
|
||||
if (m_environment.extCallback("GL_ARB_separate_shader_objects"))
|
||||
requiredExtensions.emplace_back("GL_ARB_separate_shader_objects");
|
||||
}
|
||||
}
|
||||
|
||||
if (!requiredExtensions.empty())
|
||||
{
|
||||
for (const std::string& ext : requiredExtensions)
|
||||
AppendLine("#extension " + ext + " : require");
|
||||
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
if (m_environment.glES)
|
||||
{
|
||||
AppendLine("#if GL_FRAGMENT_PRECISION_HIGH");
|
||||
AppendLine("precision highp float;");
|
||||
AppendLine("#else");
|
||||
AppendLine("precision mediump float;");
|
||||
AppendLine("#endif");
|
||||
AppendLine();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,14 +18,6 @@ namespace Nz::ShaderLang
|
|||
{ "f32", ShaderAst::PrimitiveType::Float32 },
|
||||
{ "u32", ShaderAst::PrimitiveType::UInt32 }
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, ShaderAst::IntrinsicType> s_identifierToIntrinsic = {
|
||||
{ "cross", ShaderAst::IntrinsicType::CrossProduct },
|
||||
{ "dot", ShaderAst::IntrinsicType::DotProduct },
|
||||
{ "max", ShaderAst::IntrinsicType::Max },
|
||||
{ "min", ShaderAst::IntrinsicType::Min },
|
||||
{ "length", ShaderAst::IntrinsicType::Length },
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, ShaderAst::AttributeType> s_identifierToAttributeType = {
|
||||
{ "binding", ShaderAst::AttributeType::Binding },
|
||||
|
|
@ -137,14 +129,19 @@ namespace Nz::ShaderLang
|
|||
m_context->tokenIndex += count;
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionType Parser::DecodeType(const std::string& identifier)
|
||||
std::optional<ShaderAst::ExpressionType> Parser::DecodeType(const std::string& identifier)
|
||||
{
|
||||
if (auto it = s_identifierToBasicType.find(identifier); it != s_identifierToBasicType.end())
|
||||
if (auto it = s_identifierToBasicType.find(identifier); it != s_identifierToBasicType.end())
|
||||
{
|
||||
Consume();
|
||||
return it->second;
|
||||
}
|
||||
|
||||
//FIXME: Handle this better
|
||||
if (identifier == "mat4")
|
||||
{
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::MatrixType matrixType;
|
||||
matrixType.columnCount = 4;
|
||||
matrixType.rowCount = 4;
|
||||
|
|
@ -156,7 +153,9 @@ namespace Nz::ShaderLang
|
|||
return matrixType;
|
||||
}
|
||||
else if (identifier == "sampler2D")
|
||||
{
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::SamplerType samplerType;
|
||||
samplerType.dim = ImageType_2D;
|
||||
|
||||
|
|
@ -167,7 +166,9 @@ namespace Nz::ShaderLang
|
|||
return samplerType;
|
||||
}
|
||||
else if (identifier == "uniform")
|
||||
{
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::UniformType uniformType;
|
||||
|
||||
Expect(Advance(), TokenType::LessThan); //< '<'
|
||||
|
|
@ -177,7 +178,9 @@ namespace Nz::ShaderLang
|
|||
return uniformType;
|
||||
}
|
||||
else if (identifier == "vec2")
|
||||
{
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::VectorType vectorType;
|
||||
vectorType.componentCount = 2;
|
||||
|
||||
|
|
@ -188,7 +191,9 @@ namespace Nz::ShaderLang
|
|||
return vectorType;
|
||||
}
|
||||
else if (identifier == "vec3")
|
||||
{
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::VectorType vectorType;
|
||||
vectorType.componentCount = 3;
|
||||
|
||||
|
|
@ -199,7 +204,9 @@ namespace Nz::ShaderLang
|
|||
return vectorType;
|
||||
}
|
||||
else if (identifier == "vec4")
|
||||
{
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::VectorType vectorType;
|
||||
vectorType.componentCount = 4;
|
||||
|
||||
|
|
@ -208,14 +215,9 @@ namespace Nz::ShaderLang
|
|||
Expect(Advance(), TokenType::GreatherThan); //< '>'
|
||||
|
||||
return vectorType;
|
||||
}
|
||||
else
|
||||
{
|
||||
ShaderAst::IdentifierType identifierType;
|
||||
identifierType.name = identifier;
|
||||
|
||||
return identifierType;
|
||||
}
|
||||
}
|
||||
else
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void Parser::EnterScope()
|
||||
|
|
@ -873,23 +875,19 @@ namespace Nz::ShaderLang
|
|||
{
|
||||
const std::string& identifier = std::get<std::string>(token.data);
|
||||
|
||||
if (auto it = s_identifierToIntrinsic.find(identifier); it != s_identifierToIntrinsic.end())
|
||||
// Is it a cast?
|
||||
std::optional<ShaderAst::ExpressionType> exprType = DecodeType(identifier);
|
||||
if (exprType)
|
||||
return ShaderBuilder::Cast(std::move(*exprType), ParseParameters());
|
||||
|
||||
if (Peek(1).type == TokenType::OpenParenthesis)
|
||||
{
|
||||
if (Peek(1).type == TokenType::OpenParenthesis)
|
||||
{
|
||||
Consume();
|
||||
return ShaderBuilder::Intrinsic(it->second, ParseParameters());
|
||||
}
|
||||
// Function call
|
||||
Consume();
|
||||
return ShaderBuilder::CallFunction(identifier, ParseParameters());
|
||||
}
|
||||
|
||||
if (IsVariableInScope(identifier))
|
||||
return ParseIdentifier();
|
||||
|
||||
Consume();
|
||||
|
||||
ShaderAst::ExpressionType exprType = DecodeType(identifier);
|
||||
|
||||
return ShaderBuilder::Cast(std::move(exprType), ParseParameters());
|
||||
else
|
||||
return ParseIdentifier();
|
||||
}
|
||||
|
||||
case TokenType::IntegerValue:
|
||||
|
|
@ -989,10 +987,17 @@ namespace Nz::ShaderLang
|
|||
return ShaderAst::NoType{};
|
||||
}
|
||||
|
||||
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
|
||||
const Token& identifierToken = Expect(Peek(), TokenType::Identifier);
|
||||
const std::string& identifier = std::get<std::string>(identifierToken.data);
|
||||
|
||||
return DecodeType(identifier);
|
||||
auto type = DecodeType(identifier);
|
||||
if (!type)
|
||||
{
|
||||
Consume();
|
||||
return ShaderAst::IdentifierType{ identifier };
|
||||
}
|
||||
|
||||
return *type;
|
||||
}
|
||||
|
||||
int Parser::GetTokenPrecedence(TokenType token)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Core/CallOnExit.hpp>
|
||||
#include <Nazara/Core/StackArray.hpp>
|
||||
#include <Nazara/Core/StackVector.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
|
||||
|
|
@ -402,6 +403,49 @@ namespace Nz
|
|||
m_currentBlock = &m_functionBlocks.back();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
assert(std::holds_alternative<std::size_t>(node.targetFunction));
|
||||
std::size_t functionIndex = std::get<std::size_t>(node.targetFunction);
|
||||
|
||||
UInt32 funcId = 0;
|
||||
for (const auto& func : m_funcData)
|
||||
{
|
||||
if (func.funcIndex == functionIndex)
|
||||
{
|
||||
funcId = func.funcId;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(funcId != 0);
|
||||
|
||||
const FuncData& funcData = m_funcData[m_funcIndex];
|
||||
const auto& funcCall = funcData.funcCalls[m_funcCallIndex++];
|
||||
|
||||
StackArray<UInt32> parameterIds = NazaraStackArrayNoInit(UInt32, node.parameters.size());
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
UInt32 resultId = EvaluateExpression(node.parameters[i]);
|
||||
UInt32 varId = funcData.variables[funcCall.firstVarIndex + i].varId;
|
||||
m_currentBlock->Append(SpirvOp::OpStore, varId, resultId);
|
||||
|
||||
parameterIds[i] = varId;
|
||||
}
|
||||
|
||||
UInt32 resultId = AllocateResultId();
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender)
|
||||
{
|
||||
appender(m_writer.GetTypeId(ShaderAst::GetExpressionType(node)));
|
||||
appender(resultId);
|
||||
appender(funcId);
|
||||
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
appender(parameterIds[i]);
|
||||
});
|
||||
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& targetExprType = node.targetType;
|
||||
|
|
@ -561,9 +605,9 @@ namespace Nz
|
|||
{
|
||||
assert(node.funcIndex);
|
||||
m_funcIndex = *node.funcIndex;
|
||||
m_funcCallIndex = 0;
|
||||
|
||||
auto& func = m_funcData[m_funcIndex];
|
||||
func.funcId = m_writer.AllocateResultId();
|
||||
|
||||
m_instructions.Append(SpirvOp::OpFunction, func.returnTypeId, func.funcId, 0, func.funcTypeId);
|
||||
|
||||
|
|
|
|||
|
|
@ -141,6 +141,7 @@ namespace Nz
|
|||
|
||||
auto& funcData = m_funcs[funcIndex];
|
||||
funcData.name = node.name;
|
||||
funcData.funcIndex = funcIndex;
|
||||
|
||||
if (!entryPointType)
|
||||
{
|
||||
|
|
@ -228,7 +229,6 @@ namespace Nz
|
|||
*entryPointType,
|
||||
inputStruct,
|
||||
outputStructId,
|
||||
funcIndex,
|
||||
std::move(inputs),
|
||||
std::move(outputs)
|
||||
};
|
||||
|
|
@ -253,6 +253,23 @@ namespace Nz
|
|||
m_constantCache.Register(*m_constantCache.BuildType(node.description));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
assert(m_funcIndex);
|
||||
auto& func = m_funcs[*m_funcIndex];
|
||||
|
||||
auto& funcCall = func.funcCalls.emplace_back();
|
||||
funcCall.firstVarIndex = func.variables.size();
|
||||
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
auto& var = func.variables.emplace_back();
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
|
@ -440,6 +457,10 @@ namespace Nz
|
|||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
state.extensionInstructionSet[extInst] = AllocateResultId();
|
||||
|
||||
// Assign function ID (required for forward declaration)
|
||||
for (auto& func : state.funcs)
|
||||
func.funcId = AllocateResultId();
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||
targetAst->Visit(visitor);
|
||||
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ namespace Nz
|
|||
case ShaderLanguage::SpirV:
|
||||
{
|
||||
SpirvEntryPointExtractor extractor;
|
||||
extractor.Decode(reinterpret_cast<const Nz::UInt32*>(source), sourceSize);
|
||||
extractor.Decode(reinterpret_cast<const UInt32*>(source), sourceSize / sizeof(UInt32));
|
||||
|
||||
ShaderStageTypeFlags remainingStages = shaderStages;
|
||||
for (auto& entryPoint : extractor.entryPoints)
|
||||
|
|
|
|||
Loading…
Reference in New Issue