|
|
|
|
@@ -6,6 +6,7 @@
|
|
|
|
|
#include <Nazara/Core/CallOnExit.hpp>
|
|
|
|
|
#include <Nazara/Core/StackArray.hpp>
|
|
|
|
|
#include <Nazara/Shader/ShaderBuilder.hpp>
|
|
|
|
|
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
|
|
|
|
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
|
|
|
|
#include <stdexcept>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
@@ -47,6 +48,28 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
PushScope(); //< Global scope
|
|
|
|
|
{
|
|
|
|
|
RegisterIntrinsic("cross", IntrinsicType::CrossProduct);
|
|
|
|
|
RegisterIntrinsic("dot", IntrinsicType::DotProduct);
|
|
|
|
|
RegisterIntrinsic("max", IntrinsicType::Max);
|
|
|
|
|
RegisterIntrinsic("min", IntrinsicType::Min);
|
|
|
|
|
RegisterIntrinsic("length", IntrinsicType::Length);
|
|
|
|
|
|
|
|
|
|
// Collect function name and their types
|
|
|
|
|
if (nodePtr->GetType() == NodeType::MultiStatement)
|
|
|
|
|
{
|
|
|
|
|
std::size_t functionIndex = 0;
|
|
|
|
|
|
|
|
|
|
const MultiStatement& multiStatement = static_cast<const MultiStatement&>(*nodePtr);
|
|
|
|
|
for (const auto& statementPtr : multiStatement.statements)
|
|
|
|
|
{
|
|
|
|
|
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
|
|
|
|
|
{
|
|
|
|
|
const DeclareFunctionStatement& funcDeclaration = static_cast<const DeclareFunctionStatement&>(*statementPtr);
|
|
|
|
|
m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
try
|
|
|
|
|
{
|
|
|
|
|
clone = AstCloner::Clone(nodePtr);
|
|
|
|
|
@@ -355,6 +378,71 @@ namespace Nz::ShaderAst
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node)
|
|
|
|
|
{
|
|
|
|
|
constexpr std::size_t NoFunction = std::numeric_limits<std::size_t>::max();
|
|
|
|
|
|
|
|
|
|
auto clone = std::make_unique<CallFunctionExpression>();
|
|
|
|
|
|
|
|
|
|
clone->parameters.reserve(node.parameters.size());
|
|
|
|
|
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
|
|
|
|
clone->parameters.push_back(CloneExpression(node.parameters[i]));
|
|
|
|
|
|
|
|
|
|
const DeclareFunctionStatement* referenceFunctionDeclaration;
|
|
|
|
|
if (std::holds_alternative<std::string>(node.targetFunction))
|
|
|
|
|
{
|
|
|
|
|
const std::string& functionName = std::get<std::string>(node.targetFunction);
|
|
|
|
|
|
|
|
|
|
const Identifier* identifier = FindIdentifier(functionName);
|
|
|
|
|
if (identifier)
|
|
|
|
|
{
|
|
|
|
|
if (identifier->type == Identifier::Type::Intrinsic)
|
|
|
|
|
{
|
|
|
|
|
// Intrinsic function call
|
|
|
|
|
std::vector<ExpressionPtr> parameters;
|
|
|
|
|
parameters.reserve(node.parameters.size());
|
|
|
|
|
|
|
|
|
|
for (const auto& param : node.parameters)
|
|
|
|
|
parameters.push_back(CloneExpression(param));
|
|
|
|
|
|
|
|
|
|
auto intrinsic = ShaderBuilder::Intrinsic(m_intrinsics[identifier->index], std::move(parameters));
|
|
|
|
|
Validate(*intrinsic);
|
|
|
|
|
|
|
|
|
|
return intrinsic;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
// Regular function call
|
|
|
|
|
if (identifier->type != Identifier::Type::Function)
|
|
|
|
|
throw AstError{ "function expected" };
|
|
|
|
|
|
|
|
|
|
clone->targetFunction = identifier->index;
|
|
|
|
|
referenceFunctionDeclaration = m_functions[identifier->index];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
// Identifier not found, maybe the function is declared later
|
|
|
|
|
auto it = m_functionDeclarations.find(functionName);
|
|
|
|
|
if (it == m_functionDeclarations.end())
|
|
|
|
|
throw AstError{ "function " + functionName + " does not exist" };
|
|
|
|
|
|
|
|
|
|
clone->targetFunction = it->second.second;
|
|
|
|
|
|
|
|
|
|
referenceFunctionDeclaration = it->second.first;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
std::size_t funcIndex = std::get<std::size_t>(node.targetFunction);
|
|
|
|
|
referenceFunctionDeclaration = m_functions[funcIndex];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Validate(*clone, referenceFunctionDeclaration);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(CastExpression& node)
|
|
|
|
|
{
|
|
|
|
|
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
|
|
|
|
|
@@ -426,15 +514,13 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown identifier " + node.identifier };
|
|
|
|
|
|
|
|
|
|
if (!std::holds_alternative<Variable>(identifier->value))
|
|
|
|
|
if (identifier->type != Identifier::Type::Variable)
|
|
|
|
|
throw AstError{ "expected variable identifier" };
|
|
|
|
|
|
|
|
|
|
const Variable& variable = std::get<Variable>(identifier->value);
|
|
|
|
|
|
|
|
|
|
// Replace IdentifierExpression by VariableExpression
|
|
|
|
|
auto varExpr = std::make_unique<VariableExpression>();
|
|
|
|
|
varExpr->cachedExpressionType = m_variables[variable.varIndex];
|
|
|
|
|
varExpr->variableId = variable.varIndex;
|
|
|
|
|
varExpr->cachedExpressionType = m_variableTypes[identifier->index];
|
|
|
|
|
varExpr->variableId = identifier->index;
|
|
|
|
|
|
|
|
|
|
return varExpr;
|
|
|
|
|
}
|
|
|
|
|
@@ -442,110 +528,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
|
|
|
|
|
{
|
|
|
|
|
auto clone = static_unique_pointer_cast<IntrinsicExpression>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
// Parameter validation
|
|
|
|
|
switch (clone->intrinsic)
|
|
|
|
|
{
|
|
|
|
|
case IntrinsicType::CrossProduct:
|
|
|
|
|
case IntrinsicType::DotProduct:
|
|
|
|
|
case IntrinsicType::Max:
|
|
|
|
|
case IntrinsicType::Min:
|
|
|
|
|
{
|
|
|
|
|
if (clone->parameters.size() != 2)
|
|
|
|
|
throw AstError { "Expected two parameters" };
|
|
|
|
|
|
|
|
|
|
for (auto& param : clone->parameters)
|
|
|
|
|
MandatoryExpr(param);
|
|
|
|
|
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
|
|
|
|
|
|
|
|
|
for (std::size_t i = 1; i < clone->parameters.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
if (type != GetExpressionType(*clone->parameters[i]))
|
|
|
|
|
throw AstError{ "All type must match" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::Length:
|
|
|
|
|
{
|
|
|
|
|
if (clone->parameters.size() != 1)
|
|
|
|
|
throw AstError{ "Expected only one parameters" };
|
|
|
|
|
|
|
|
|
|
for (auto& param : clone->parameters)
|
|
|
|
|
MandatoryExpr(param);
|
|
|
|
|
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
|
|
|
|
if (!IsVectorType(type))
|
|
|
|
|
throw AstError{ "Expected a vector" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::SampleTexture:
|
|
|
|
|
{
|
|
|
|
|
if (clone->parameters.size() != 2)
|
|
|
|
|
throw AstError{ "Expected two parameters" };
|
|
|
|
|
|
|
|
|
|
for (auto& param : clone->parameters)
|
|
|
|
|
MandatoryExpr(param);
|
|
|
|
|
|
|
|
|
|
if (!IsSamplerType(GetExpressionType(*clone->parameters[0])))
|
|
|
|
|
throw AstError{ "First parameter must be a sampler" };
|
|
|
|
|
|
|
|
|
|
if (!IsVectorType(GetExpressionType(*clone->parameters[1])))
|
|
|
|
|
throw AstError{ "Second parameter must be a vector" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Return type attribution
|
|
|
|
|
switch (clone->intrinsic)
|
|
|
|
|
{
|
|
|
|
|
case IntrinsicType::CrossProduct:
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
|
|
|
|
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
|
|
|
|
|
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::DotProduct:
|
|
|
|
|
case IntrinsicType::Length:
|
|
|
|
|
{
|
|
|
|
|
ExpressionType type = GetExpressionType(*clone->parameters.front());
|
|
|
|
|
if (!IsVectorType(type))
|
|
|
|
|
throw AstError{ "DotProduct expects vector types" };
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = std::get<VectorType>(type).type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::Max:
|
|
|
|
|
case IntrinsicType::Min:
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*clone->parameters.front());
|
|
|
|
|
if (!IsPrimitiveType(type) && !IsVectorType(type))
|
|
|
|
|
throw AstError{ "max and min only work with primitive and vector types" };
|
|
|
|
|
|
|
|
|
|
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) == PrimitiveType::Boolean) ||
|
|
|
|
|
(IsVectorType(type) && std::get<VectorType>(type).type == PrimitiveType::Boolean))
|
|
|
|
|
throw AstError{ "max and min do not work with booleans" };
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::SampleTexture:
|
|
|
|
|
{
|
|
|
|
|
clone->cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*clone->parameters.front())).sampledType };
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -563,10 +546,10 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown option " + node.optionName };
|
|
|
|
|
|
|
|
|
|
if (!std::holds_alternative<Option>(identifier->value))
|
|
|
|
|
if (identifier->type != Identifier::Type::Option)
|
|
|
|
|
throw AstError{ "expected option identifier" };
|
|
|
|
|
|
|
|
|
|
condExpr->optionIndex = std::get<Option>(identifier->value).optionIndex;
|
|
|
|
|
condExpr->optionIndex = identifier->index;
|
|
|
|
|
|
|
|
|
|
const ExpressionType& leftExprType = GetExpressionType(*condExpr->truePath);
|
|
|
|
|
if (leftExprType != GetExpressionType(*condExpr->falsePath))
|
|
|
|
|
@@ -754,7 +737,6 @@ namespace Nz::ShaderAst
|
|
|
|
|
auto clone = std::make_unique<DeclareFunctionStatement>();
|
|
|
|
|
clone->entryStage = node.entryStage;
|
|
|
|
|
clone->name = node.name;
|
|
|
|
|
clone->funcIndex = m_nextFuncIndex++;
|
|
|
|
|
clone->optionName = node.optionName;
|
|
|
|
|
clone->parameters = node.parameters;
|
|
|
|
|
clone->returnType = ResolveType(node.returnType);
|
|
|
|
|
@@ -785,14 +767,16 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown option " + node.optionName };
|
|
|
|
|
|
|
|
|
|
if (!std::holds_alternative<Option>(identifier->value))
|
|
|
|
|
if (identifier->type != Identifier::Type::Option)
|
|
|
|
|
throw AstError{ "expected option identifier" };
|
|
|
|
|
|
|
|
|
|
std::size_t optionIndex = std::get<Option>(identifier->value).optionIndex;
|
|
|
|
|
std::size_t optionIndex = identifier->index;
|
|
|
|
|
|
|
|
|
|
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
clone->funcIndex = RegisterFunction(clone.get());
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -905,6 +889,105 @@ namespace Nz::ShaderAst
|
|
|
|
|
m_scopeSizes.pop_back();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterFunction(DeclareFunctionStatement* funcDecl)
|
|
|
|
|
{
|
|
|
|
|
if (auto* identifier = FindIdentifier(funcDecl->name))
|
|
|
|
|
{
|
|
|
|
|
bool duplicate = true;
|
|
|
|
|
|
|
|
|
|
// Functions cannot be declared twice, except for entry ones if their stages are different
|
|
|
|
|
if (funcDecl->entryStage && identifier->type == Identifier::Type::Function)
|
|
|
|
|
{
|
|
|
|
|
auto& otherFunction = m_functions[identifier->index];
|
|
|
|
|
if (funcDecl->entryStage != otherFunction->entryStage)
|
|
|
|
|
duplicate = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (duplicate)
|
|
|
|
|
throw AstError{ funcDecl->name + " is already used" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t functionIndex = m_functions.size();
|
|
|
|
|
m_functions.push_back(funcDecl);
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
funcDecl->name,
|
|
|
|
|
functionIndex,
|
|
|
|
|
Identifier::Type::Function
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return functionIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
|
|
|
|
|
{
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t intrinsicIndex = m_intrinsics.size();
|
|
|
|
|
m_intrinsics.push_back(type);
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
intrinsicIndex,
|
|
|
|
|
Identifier::Type::Intrinsic
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return intrinsicIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type)
|
|
|
|
|
{
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t optionIndex = m_options.size();
|
|
|
|
|
m_options.emplace_back(std::move(type));
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
optionIndex,
|
|
|
|
|
Identifier::Type::Option
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return optionIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description)
|
|
|
|
|
{
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t structIndex = m_structs.size();
|
|
|
|
|
m_structs.emplace_back(std::move(description));
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
structIndex,
|
|
|
|
|
Identifier::Type::Struct
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return structIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type)
|
|
|
|
|
{
|
|
|
|
|
// Allow variable shadowing
|
|
|
|
|
if (auto* identifier = FindIdentifier(name); identifier && identifier->type != Identifier::Type::Variable)
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t varIndex = m_variableTypes.size();
|
|
|
|
|
m_variableTypes.emplace_back(std::move(type));
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
varIndex,
|
|
|
|
|
Identifier::Type::Variable
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return varIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
|
|
|
|
|
{
|
|
|
|
|
return std::visit([&](auto&& arg) -> std::size_t
|
|
|
|
|
@@ -932,10 +1015,10 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown identifier " + identifierType.name };
|
|
|
|
|
|
|
|
|
|
if (!std::holds_alternative<Struct>(identifier->value))
|
|
|
|
|
if (identifier->type != Identifier::Type::Struct)
|
|
|
|
|
throw AstError{ identifierType.name + " is not a struct" };
|
|
|
|
|
|
|
|
|
|
return std::get<Struct>(identifier->value).structIndex;
|
|
|
|
|
return identifier->index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType)
|
|
|
|
|
@@ -977,10 +1060,10 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown identifier " + arg.name };
|
|
|
|
|
|
|
|
|
|
if (!std::holds_alternative<Struct>(identifier->value))
|
|
|
|
|
if (identifier->type != Identifier::Type::Struct)
|
|
|
|
|
throw AstError{ "expected type identifier" };
|
|
|
|
|
|
|
|
|
|
return StructType{ std::get<Struct>(identifier->value).structIndex };
|
|
|
|
|
return StructType{ identifier->index };
|
|
|
|
|
}
|
|
|
|
|
else if constexpr (std::is_same_v<T, UniformType>)
|
|
|
|
|
{
|
|
|
|
|
@@ -1010,6 +1093,130 @@ namespace Nz::ShaderAst
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
|
|
|
|
|
{
|
|
|
|
|
if (referenceDeclaration->entryStage)
|
|
|
|
|
throw AstError{ referenceDeclaration->name + " is an entry function which cannot be called by the program" };
|
|
|
|
|
|
|
|
|
|
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
if (GetExpressionType(*node.parameters[i]) != referenceDeclaration->parameters[i].type)
|
|
|
|
|
throw AstError{ "function " + referenceDeclaration->name + " parameter " + std::to_string(i) + " type mismatch" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node.parameters.size() != referenceDeclaration->parameters.size())
|
|
|
|
|
throw AstError{ "function " + referenceDeclaration->name + " expected " + std::to_string(referenceDeclaration->parameters.size()) + " parameters, got " + std::to_string(node.parameters.size()) };
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = referenceDeclaration->returnType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(IntrinsicExpression& node)
|
|
|
|
|
{
|
|
|
|
|
// Parameter validation
|
|
|
|
|
switch (node.intrinsic)
|
|
|
|
|
{
|
|
|
|
|
case IntrinsicType::CrossProduct:
|
|
|
|
|
case IntrinsicType::DotProduct:
|
|
|
|
|
case IntrinsicType::Max:
|
|
|
|
|
case IntrinsicType::Min:
|
|
|
|
|
{
|
|
|
|
|
if (node.parameters.size() != 2)
|
|
|
|
|
throw AstError { "Expected two parameters" };
|
|
|
|
|
|
|
|
|
|
for (auto& param : node.parameters)
|
|
|
|
|
MandatoryExpr(param);
|
|
|
|
|
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
|
|
|
|
|
|
|
|
|
for (std::size_t i = 1; i < node.parameters.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
if (type != GetExpressionType(*node.parameters[i]))
|
|
|
|
|
throw AstError{ "All type must match" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::Length:
|
|
|
|
|
{
|
|
|
|
|
if (node.parameters.size() != 1)
|
|
|
|
|
throw AstError{ "Expected only one parameters" };
|
|
|
|
|
|
|
|
|
|
for (auto& param : node.parameters)
|
|
|
|
|
MandatoryExpr(param);
|
|
|
|
|
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
|
|
|
|
if (!IsVectorType(type))
|
|
|
|
|
throw AstError{ "Expected a vector" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::SampleTexture:
|
|
|
|
|
{
|
|
|
|
|
if (node.parameters.size() != 2)
|
|
|
|
|
throw AstError{ "Expected two parameters" };
|
|
|
|
|
|
|
|
|
|
for (auto& param : node.parameters)
|
|
|
|
|
MandatoryExpr(param);
|
|
|
|
|
|
|
|
|
|
if (!IsSamplerType(GetExpressionType(*node.parameters[0])))
|
|
|
|
|
throw AstError{ "First parameter must be a sampler" };
|
|
|
|
|
|
|
|
|
|
if (!IsVectorType(GetExpressionType(*node.parameters[1])))
|
|
|
|
|
throw AstError{ "Second parameter must be a vector" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Return type attribution
|
|
|
|
|
switch (node.intrinsic)
|
|
|
|
|
{
|
|
|
|
|
case IntrinsicType::CrossProduct:
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
|
|
|
|
if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
|
|
|
|
|
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::DotProduct:
|
|
|
|
|
case IntrinsicType::Length:
|
|
|
|
|
{
|
|
|
|
|
ExpressionType type = GetExpressionType(*node.parameters.front());
|
|
|
|
|
if (!IsVectorType(type))
|
|
|
|
|
throw AstError{ "DotProduct expects vector types" };
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = std::get<VectorType>(type).type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::Max:
|
|
|
|
|
case IntrinsicType::Min:
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& type = GetExpressionType(*node.parameters.front());
|
|
|
|
|
if (!IsPrimitiveType(type) && !IsVectorType(type))
|
|
|
|
|
throw AstError{ "max and min only work with primitive and vector types" };
|
|
|
|
|
|
|
|
|
|
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) == PrimitiveType::Boolean) ||
|
|
|
|
|
(IsVectorType(type) && std::get<VectorType>(type).type == PrimitiveType::Boolean))
|
|
|
|
|
throw AstError{ "max and min do not work with booleans" };
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case IntrinsicType::SampleTexture:
|
|
|
|
|
{
|
|
|
|
|
node.cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*node.parameters.front())).sampledType };
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
|
|
|
|
|
{
|
|
|
|
|
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
|
|
|
|
|
|