Shader: Add support for custom functions calls (and better handle intrinsics)

This commit is contained in:
Jérôme Leclercq
2021-05-22 13:37:54 +02:00
parent 8a6f0db034
commit f6fd996bf1
24 changed files with 777 additions and 356 deletions

View File

@@ -202,6 +202,36 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(CallFunctionExpression& node)
{
auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = node.targetFunction;
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
clone->parameters.push_back(CloneExpression(parameter));
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(CallMethodExpression& node)
{
auto clone = std::make_unique<CallMethodExpression>();
clone->methodName = node.methodName;
clone->object = CloneExpression(node.object);
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
clone->parameters.push_back(CloneExpression(parameter));
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(CastExpression& node)
{
auto clone = std::make_unique<CastExpression>();

View File

@@ -29,6 +29,20 @@ namespace Nz::ShaderAst
node.right->Visit(*this);
}
void AstRecursiveVisitor::Visit(CallFunctionExpression& node)
{
for (auto& param : node.parameters)
param->Visit(*this);
}
void AstRecursiveVisitor::Visit(CallMethodExpression& node)
{
node.object->Visit(*this);
for (auto& param : node.parameters)
param->Visit(*this);
}
void AstRecursiveVisitor::Visit(CastExpression& node)
{
for (auto& expr : node.expressions)

View File

@@ -65,6 +65,45 @@ namespace Nz::ShaderAst
Node(node.right);
}
void AstSerializerBase::Serialize(CallFunctionExpression& node)
{
UInt32 typeIndex;
if (IsWriting())
typeIndex = UInt32(node.targetFunction.index());
Value(typeIndex);
// Waiting for template lambda in C++20
auto SerializeValue = [&](auto dummyType)
{
using T = std::decay_t<decltype(dummyType)>;
auto& value = (IsWriting()) ? std::get<T>(node.targetFunction) : node.targetFunction.emplace<T>();
Value(value);
};
static_assert(std::variant_size_v<decltype(node.targetFunction)> == 2);
switch (typeIndex)
{
case 0: SerializeValue(std::string()); break;
case 1: SerializeValue(std::size_t()); break;
}
Container(node.parameters);
for (auto& param : node.parameters)
Node(param);
}
void AstSerializerBase::Serialize(CallMethodExpression& node)
{
Node(node.object);
Value(node.methodName);
Container(node.parameters);
for (auto& param : node.parameters)
Node(param);
}
void AstSerializerBase::Serialize(CastExpression& node)
{
Type(node.targetType);

View File

@@ -33,6 +33,16 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(CallFunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(CallMethodExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(CastExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::RValue;

View File

@@ -6,6 +6,7 @@
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <stdexcept>
#include <unordered_set>
@@ -47,6 +48,28 @@ namespace Nz::ShaderAst
PushScope(); //< Global scope
{
RegisterIntrinsic("cross", IntrinsicType::CrossProduct);
RegisterIntrinsic("dot", IntrinsicType::DotProduct);
RegisterIntrinsic("max", IntrinsicType::Max);
RegisterIntrinsic("min", IntrinsicType::Min);
RegisterIntrinsic("length", IntrinsicType::Length);
// Collect function name and their types
if (nodePtr->GetType() == NodeType::MultiStatement)
{
std::size_t functionIndex = 0;
const MultiStatement& multiStatement = static_cast<const MultiStatement&>(*nodePtr);
for (const auto& statementPtr : multiStatement.statements)
{
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
{
const DeclareFunctionStatement& funcDeclaration = static_cast<const DeclareFunctionStatement&>(*statementPtr);
m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++));
}
}
}
try
{
clone = AstCloner::Clone(nodePtr);
@@ -355,6 +378,71 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node)
{
constexpr std::size_t NoFunction = std::numeric_limits<std::size_t>::max();
auto clone = std::make_unique<CallFunctionExpression>();
clone->parameters.reserve(node.parameters.size());
for (std::size_t i = 0; i < node.parameters.size(); ++i)
clone->parameters.push_back(CloneExpression(node.parameters[i]));
const DeclareFunctionStatement* referenceFunctionDeclaration;
if (std::holds_alternative<std::string>(node.targetFunction))
{
const std::string& functionName = std::get<std::string>(node.targetFunction);
const Identifier* identifier = FindIdentifier(functionName);
if (identifier)
{
if (identifier->type == Identifier::Type::Intrinsic)
{
// Intrinsic function call
std::vector<ExpressionPtr> parameters;
parameters.reserve(node.parameters.size());
for (const auto& param : node.parameters)
parameters.push_back(CloneExpression(param));
auto intrinsic = ShaderBuilder::Intrinsic(m_intrinsics[identifier->index], std::move(parameters));
Validate(*intrinsic);
return intrinsic;
}
else
{
// Regular function call
if (identifier->type != Identifier::Type::Function)
throw AstError{ "function expected" };
clone->targetFunction = identifier->index;
referenceFunctionDeclaration = m_functions[identifier->index];
}
}
else
{
// Identifier not found, maybe the function is declared later
auto it = m_functionDeclarations.find(functionName);
if (it == m_functionDeclarations.end())
throw AstError{ "function " + functionName + " does not exist" };
clone->targetFunction = it->second.second;
referenceFunctionDeclaration = it->second.first;
}
}
else
{
std::size_t funcIndex = std::get<std::size_t>(node.targetFunction);
referenceFunctionDeclaration = m_functions[funcIndex];
}
Validate(*clone, referenceFunctionDeclaration);
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(CastExpression& node)
{
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
@@ -426,15 +514,13 @@ namespace Nz::ShaderAst
if (!identifier)
throw AstError{ "unknown identifier " + node.identifier };
if (!std::holds_alternative<Variable>(identifier->value))
if (identifier->type != Identifier::Type::Variable)
throw AstError{ "expected variable identifier" };
const Variable& variable = std::get<Variable>(identifier->value);
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_variables[variable.varIndex];
varExpr->variableId = variable.varIndex;
varExpr->cachedExpressionType = m_variableTypes[identifier->index];
varExpr->variableId = identifier->index;
return varExpr;
}
@@ -442,110 +528,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
{
auto clone = static_unique_pointer_cast<IntrinsicExpression>(AstCloner::Clone(node));
// Parameter validation
switch (clone->intrinsic)
{
case IntrinsicType::CrossProduct:
case IntrinsicType::DotProduct:
case IntrinsicType::Max:
case IntrinsicType::Min:
{
if (clone->parameters.size() != 2)
throw AstError { "Expected two parameters" };
for (auto& param : clone->parameters)
MandatoryExpr(param);
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
for (std::size_t i = 1; i < clone->parameters.size(); ++i)
{
if (type != GetExpressionType(*clone->parameters[i]))
throw AstError{ "All type must match" };
}
break;
}
case IntrinsicType::Length:
{
if (clone->parameters.size() != 1)
throw AstError{ "Expected only one parameters" };
for (auto& param : clone->parameters)
MandatoryExpr(param);
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
if (!IsVectorType(type))
throw AstError{ "Expected a vector" };
break;
}
case IntrinsicType::SampleTexture:
{
if (clone->parameters.size() != 2)
throw AstError{ "Expected two parameters" };
for (auto& param : clone->parameters)
MandatoryExpr(param);
if (!IsSamplerType(GetExpressionType(*clone->parameters[0])))
throw AstError{ "First parameter must be a sampler" };
if (!IsVectorType(GetExpressionType(*clone->parameters[1])))
throw AstError{ "Second parameter must be a vector" };
break;
}
}
// Return type attribution
switch (clone->intrinsic)
{
case IntrinsicType::CrossProduct:
{
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
clone->cachedExpressionType = type;
break;
}
case IntrinsicType::DotProduct:
case IntrinsicType::Length:
{
ExpressionType type = GetExpressionType(*clone->parameters.front());
if (!IsVectorType(type))
throw AstError{ "DotProduct expects vector types" };
clone->cachedExpressionType = std::get<VectorType>(type).type;
break;
}
case IntrinsicType::Max:
case IntrinsicType::Min:
{
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
if (!IsPrimitiveType(type) && !IsVectorType(type))
throw AstError{ "max and min only work with primitive and vector types" };
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) == PrimitiveType::Boolean) ||
(IsVectorType(type) && std::get<VectorType>(type).type == PrimitiveType::Boolean))
throw AstError{ "max and min do not work with booleans" };
clone->cachedExpressionType = type;
break;
}
case IntrinsicType::SampleTexture:
{
clone->cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*clone->parameters.front())).sampledType };
break;
}
}
Validate(*clone);
return clone;
}
@@ -563,10 +546,10 @@ namespace Nz::ShaderAst
if (!identifier)
throw AstError{ "unknown option " + node.optionName };
if (!std::holds_alternative<Option>(identifier->value))
if (identifier->type != Identifier::Type::Option)
throw AstError{ "expected option identifier" };
condExpr->optionIndex = std::get<Option>(identifier->value).optionIndex;
condExpr->optionIndex = identifier->index;
const ExpressionType& leftExprType = GetExpressionType(*condExpr->truePath);
if (leftExprType != GetExpressionType(*condExpr->falsePath))
@@ -754,7 +737,6 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->entryStage = node.entryStage;
clone->name = node.name;
clone->funcIndex = m_nextFuncIndex++;
clone->optionName = node.optionName;
clone->parameters = node.parameters;
clone->returnType = ResolveType(node.returnType);
@@ -785,14 +767,16 @@ namespace Nz::ShaderAst
if (!identifier)
throw AstError{ "unknown option " + node.optionName };
if (!std::holds_alternative<Option>(identifier->value))
if (identifier->type != Identifier::Type::Option)
throw AstError{ "expected option identifier" };
std::size_t optionIndex = std::get<Option>(identifier->value).optionIndex;
std::size_t optionIndex = identifier->index;
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
}
clone->funcIndex = RegisterFunction(clone.get());
return clone;
}
@@ -905,6 +889,105 @@ namespace Nz::ShaderAst
m_scopeSizes.pop_back();
}
std::size_t SanitizeVisitor::RegisterFunction(DeclareFunctionStatement* funcDecl)
{
if (auto* identifier = FindIdentifier(funcDecl->name))
{
bool duplicate = true;
// Functions cannot be declared twice, except for entry ones if their stages are different
if (funcDecl->entryStage && identifier->type == Identifier::Type::Function)
{
auto& otherFunction = m_functions[identifier->index];
if (funcDecl->entryStage != otherFunction->entryStage)
duplicate = false;
}
if (duplicate)
throw AstError{ funcDecl->name + " is already used" };
}
std::size_t functionIndex = m_functions.size();
m_functions.push_back(funcDecl);
m_identifiersInScope.push_back({
funcDecl->name,
functionIndex,
Identifier::Type::Function
});
return functionIndex;
}
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
{
if (FindIdentifier(name))
throw AstError{ name + " is already used" };
std::size_t intrinsicIndex = m_intrinsics.size();
m_intrinsics.push_back(type);
m_identifiersInScope.push_back({
std::move(name),
intrinsicIndex,
Identifier::Type::Intrinsic
});
return intrinsicIndex;
}
std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type)
{
if (FindIdentifier(name))
throw AstError{ name + " is already used" };
std::size_t optionIndex = m_options.size();
m_options.emplace_back(std::move(type));
m_identifiersInScope.push_back({
std::move(name),
optionIndex,
Identifier::Type::Option
});
return optionIndex;
}
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description)
{
if (FindIdentifier(name))
throw AstError{ name + " is already used" };
std::size_t structIndex = m_structs.size();
m_structs.emplace_back(std::move(description));
m_identifiersInScope.push_back({
std::move(name),
structIndex,
Identifier::Type::Struct
});
return structIndex;
}
std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type)
{
// Allow variable shadowing
if (auto* identifier = FindIdentifier(name); identifier && identifier->type != Identifier::Type::Variable)
throw AstError{ name + " is already used" };
std::size_t varIndex = m_variableTypes.size();
m_variableTypes.emplace_back(std::move(type));
m_identifiersInScope.push_back({
std::move(name),
varIndex,
Identifier::Type::Variable
});
return varIndex;
}
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
{
return std::visit([&](auto&& arg) -> std::size_t
@@ -932,10 +1015,10 @@ namespace Nz::ShaderAst
if (!identifier)
throw AstError{ "unknown identifier " + identifierType.name };
if (!std::holds_alternative<Struct>(identifier->value))
if (identifier->type != Identifier::Type::Struct)
throw AstError{ identifierType.name + " is not a struct" };
return std::get<Struct>(identifier->value).structIndex;
return identifier->index;
}
std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType)
@@ -977,10 +1060,10 @@ namespace Nz::ShaderAst
if (!identifier)
throw AstError{ "unknown identifier " + arg.name };
if (!std::holds_alternative<Struct>(identifier->value))
if (identifier->type != Identifier::Type::Struct)
throw AstError{ "expected type identifier" };
return StructType{ std::get<Struct>(identifier->value).structIndex };
return StructType{ identifier->index };
}
else if constexpr (std::is_same_v<T, UniformType>)
{
@@ -1010,6 +1093,130 @@ namespace Nz::ShaderAst
}
}
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
{
if (referenceDeclaration->entryStage)
throw AstError{ referenceDeclaration->name + " is an entry function which cannot be called by the program" };
for (std::size_t i = 0; i < node.parameters.size(); ++i)
{
if (GetExpressionType(*node.parameters[i]) != referenceDeclaration->parameters[i].type)
throw AstError{ "function " + referenceDeclaration->name + " parameter " + std::to_string(i) + " type mismatch" };
}
if (node.parameters.size() != referenceDeclaration->parameters.size())
throw AstError{ "function " + referenceDeclaration->name + " expected " + std::to_string(referenceDeclaration->parameters.size()) + " parameters, got " + std::to_string(node.parameters.size()) };
node.cachedExpressionType = referenceDeclaration->returnType;
}
void SanitizeVisitor::Validate(IntrinsicExpression& node)
{
// Parameter validation
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
case IntrinsicType::DotProduct:
case IntrinsicType::Max:
case IntrinsicType::Min:
{
if (node.parameters.size() != 2)
throw AstError { "Expected two parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
const ExpressionType& type = GetExpressionType(*node.parameters.front());
for (std::size_t i = 1; i < node.parameters.size(); ++i)
{
if (type != GetExpressionType(*node.parameters[i]))
throw AstError{ "All type must match" };
}
break;
}
case IntrinsicType::Length:
{
if (node.parameters.size() != 1)
throw AstError{ "Expected only one parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
const ExpressionType& type = GetExpressionType(*node.parameters.front());
if (!IsVectorType(type))
throw AstError{ "Expected a vector" };
break;
}
case IntrinsicType::SampleTexture:
{
if (node.parameters.size() != 2)
throw AstError{ "Expected two parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
if (!IsSamplerType(GetExpressionType(*node.parameters[0])))
throw AstError{ "First parameter must be a sampler" };
if (!IsVectorType(GetExpressionType(*node.parameters[1])))
throw AstError{ "Second parameter must be a vector" };
break;
}
}
// Return type attribution
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
{
const ExpressionType& type = GetExpressionType(*node.parameters.front());
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
node.cachedExpressionType = type;
break;
}
case IntrinsicType::DotProduct:
case IntrinsicType::Length:
{
ExpressionType type = GetExpressionType(*node.parameters.front());
if (!IsVectorType(type))
throw AstError{ "DotProduct expects vector types" };
node.cachedExpressionType = std::get<VectorType>(type).type;
break;
}
case IntrinsicType::Max:
case IntrinsicType::Min:
{
const ExpressionType& type = GetExpressionType(*node.parameters.front());
if (!IsPrimitiveType(type) && !IsVectorType(type))
throw AstError{ "max and min only work with primitive and vector types" };
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) == PrimitiveType::Boolean) ||
(IsVectorType(type) && std::get<VectorType>(type).type == PrimitiveType::Boolean))
throw AstError{ "max and min do not work with booleans" };
node.cachedExpressionType = type;
break;
}
case IntrinsicType::SampleTexture:
{
node.cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*node.parameters.front())).sampledType };
break;
}
}
}
void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
{
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));