// NazaraEngine/src/Nazara/Shader/Ast/SanitizeVisitor.cpp
//
// 1003 lines
// 29 KiB
// C++

// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <stdexcept>
#include <unordered_set>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
namespace
{
    // Error raised internally while sanitizing; it only carries a human-readable message.
    // It must stay an aggregate: the file builds it with `AstError{ "message" }`.
    struct AstError
    {
        std::string errMsg;
    };

    // Moves ownership out of `ptr` into a unique_ptr<T>, converting the raw pointer with static_cast.
    // `ptr` is left empty. No runtime type check is performed.
    template<typename T, typename U>
    std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
    {
        U* released = ptr.release();
        return std::unique_ptr<T>(static_cast<T*>(released));
    }
}
// State local to one Sanitize() call; m_context points to it while the call is running.
struct SanitizeVisitor::Context
{
// Copy of the options given to Sanitize()
Options options;
// Entry function already seen for each shader stage (slot index is UnderlyingCast(stage)); null when none was seen
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
// Names of the external variables declared so far, used to reject duplicate declarations
std::unordered_set<std::string> declaredExternalVar;
// Binding indexes already taken by external variables, used to reject duplicate bindings
std::unordered_set<unsigned int> usedBindingIndexes;
};
// Clones the AST rooted at nodePtr while validating and rewriting it.
//
// nodePtr: root statement to sanitize.
// options: sanitization options, copied into the per-call context.
// error:   optional output. When non-null, a sanitization error is written to it and a null
//          StatementPtr is returned. When null, the error is rethrown as std::runtime_error.
//
// The scope stack is always unwound to the depth it had on entry, including when an exception
// leaves this function. Visitors push scopes without guarding them, so an error raised deep in the
// tree would otherwise leave stale scopes behind.
StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& nodePtr, const Options& options, std::string* error)
{
    StatementPtr clone;

    Context currentContext;
    currentContext.options = options;

    m_context = &currentContext;
    CallOnExit resetContext([&] { m_context = nullptr; });

    // Restores the scope stack on every exit path (normal return, handled error, rethrow).
    // Declared after resetContext so that it runs first, while m_context is still set.
    const std::size_t initialScopeDepth = m_scopeSizes.size();
    CallOnExit unwindScopes([&]
    {
        while (m_scopeSizes.size() > initialScopeDepth)
            PopScope();
    });

    PushScope(); //< Global scope

    try
    {
        clone = AstCloner::Clone(nodePtr);
    }
    catch (const AstError& err)
    {
        if (!error)
            throw std::runtime_error(err.errMsg);

        *error = err.errMsg;
    }

    return clone;
}
// Walks a chain of member names starting from structType.
// For each name, the index of the struct being accessed is written to structIndices (one entry per name),
// and the type of the last member of the chain is returned.
// Throws AstError when a name does not match any member of the current struct.
const ExpressionType& SanitizeVisitor::CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices)
{
    const ExpressionType* currentType = &structType;
    for (;;)
    {
        const std::size_t resolvedIndex = ResolveStruct(*currentType);
        *structIndices++ = resolvedIndex;

        assert(resolvedIndex < m_structs.size());
        const StructDescription& desc = m_structs[resolvedIndex];

        const std::string& wantedName = *memberIdentifier;
        auto fieldIt = std::find_if(desc.members.begin(), desc.members.end(), [&](const auto& field) { return field.name == wantedName; });
        if (fieldIt == desc.members.end())
            throw AstError{ "unknown field " + wantedName };

        currentType = &fieldIt->type;
        if (remainingMembers <= 1)
            return *currentType;

        // Move on to the next name of the chain
        ++memberIdentifier;
        --remainingMembers;
    }
}
// Sanitizes a member access made by name.
// - If the accessed expression is a vector, the access becomes a SwizzleExpression. Every character
//   of the swizzle must be a known component letter and must designate a component the vector has.
// - Otherwise the names are resolved against the registered structs and the node is replaced by an
//   AccessMemberIndexExpression holding one member index per name.
// Throws AstError on a missing expression, an invalid swizzle or an unknown field.
ExpressionPtr SanitizeVisitor::Clone(AccessMemberIdentifierExpression& node)
{
    auto structExpr = CloneExpression(MandatoryExpr(node.structExpr));

    const ExpressionType& exprType = GetExpressionType(*structExpr);
    if (IsVectorType(exprType))
    {
        const VectorType& swizzledVec = std::get<VectorType>(exprType);

        // Swizzle expression
        auto swizzle = std::make_unique<SwizzleExpression>();
        swizzle->expression = std::move(structExpr);

        // FIXME: Handle properly multiple identifiers (treat recursively)
        if (node.memberIdentifiers.size() != 1)
            throw AstError{ "invalid swizzle" };

        const std::string& swizzleStr = node.memberIdentifiers.front();
        if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
            throw AstError{ "invalid swizzle" };

        swizzle->componentCount = swizzleStr.size();

        if (swizzle->componentCount > 1)
            swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
        else
            swizzle->cachedExpressionType = swizzledVec.type;

        for (std::size_t i = 0; i < swizzle->componentCount; ++i)
        {
            // Zero-based position of the selected component, used for the range check below
            std::size_t componentIndex = 0;
            switch (swizzleStr[i])
            {
                case 'r':
                case 'x':
                case 's':
                    swizzle->components[i] = SwizzleComponent::First;
                    componentIndex = 0;
                    break;

                case 'g':
                case 'y':
                case 't':
                    swizzle->components[i] = SwizzleComponent::Second;
                    componentIndex = 1;
                    break;

                case 'b':
                case 'z':
                case 'p':
                    swizzle->components[i] = SwizzleComponent::Third;
                    componentIndex = 2;
                    break;

                case 'a':
                case 'w':
                case 'q':
                    swizzle->components[i] = SwizzleComponent::Fourth;
                    componentIndex = 3;
                    break;

                default:
                    // Without this case an unknown letter would leave components[i] unassigned
                    throw AstError{ "unexpected character '" + std::string(1, swizzleStr[i]) + "' in swizzle" };
            }

            // The selected component must exist in the swizzled vector (e.g. no `.z` on a two-component vector)
            if (componentIndex >= swizzledVec.componentCount)
                throw AstError{ "swizzle component '" + std::string(1, swizzleStr[i]) + "' is out of range for the swizzled vector" };
        }

        return swizzle;
    }

    // Transform to AccessMemberIndexExpression
    auto accessMemberIndex = std::make_unique<AccessMemberIndexExpression>();
    accessMemberIndex->structExpr = std::move(structExpr);

    // CheckField fills one struct index per member name and returns the type of the last member
    StackArray<std::size_t> structIndices = NazaraStackArrayNoInit(std::size_t, node.memberIdentifiers.size());

    accessMemberIndex->cachedExpressionType = ResolveType(CheckField(exprType, node.memberIdentifiers.data(), node.memberIdentifiers.size(), structIndices.data()));

    accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size());
    for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i)
    {
        std::size_t structIndex = structIndices[i];
        assert(structIndex < m_structs.size());
        const StructDescription& structDesc = m_structs[structIndex];

        // CheckField already validated the name, so the lookup is expected to succeed
        auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; });
        assert(it != structDesc.members.end());

        accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it);
    }

    return accessMemberIndex;
}
// Sanitizes an assignment: both sides must exist, the left side must be an l-value
// and both sides must have the same type. The expression takes the type of its right side.
ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node)
{
    MandatoryExpr(node.left);
    MandatoryExpr(node.right);

    const bool targetIsLValue = (GetExpressionCategory(*node.left) == ExpressionCategory::LValue);
    if (!targetIsLValue)
        throw AstError{ "Assignation is only possible with a l-value" };

    auto assignClone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));

    TypeMustMatch(assignClone->left, assignClone->right);
    assignClone->cachedExpressionType = GetExpressionType(*assignClone->right);

    return assignClone;
}
// Sanitizes a binary expression: checks that both operand types support the operator
// and computes the resulting type.
//
// Operands must be primitives, vectors or matrices. Rules, by left operand kind:
// - primitive: every comparison (ordering and equality) yields Boolean; ordering comparisons are
//   rejected on booleans; add/subtract keep the operand type; multiply/divide accept a primitive,
//   vector or matrix right operand of the same base type.
// - matrix: comparisons yield Boolean; add/subtract keep the matrix type; multiply/divide accept
//   a matrix, a primitive or a vector with as many components as the matrix has columns.
// - vector: comparisons yield Boolean; add/subtract keep the vector type; multiply/divide accept
//   a primitive of the same base type or an identical vector.
// Throws AstError when a rule is violated.
ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node)
{
    auto clone = static_unique_pointer_cast<BinaryExpression>(AstCloner::Clone(node));

    const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(clone->left));
    if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
        throw AstError{ "left expression type does not support binary operation" };

    const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(clone->right));
    if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
        throw AstError{ "right expression type does not support binary operation" };

    if (IsPrimitiveType(leftExprType))
    {
        PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
        switch (clone->op)
        {
            case BinaryType::CompGe:
            case BinaryType::CompGt:
            case BinaryType::CompLe:
            case BinaryType::CompLt:
                if (leftType == PrimitiveType::Boolean)
                    throw AstError{ "this operation is not supported for booleans" };

                [[fallthrough]];
            case BinaryType::CompEq:
            case BinaryType::CompNe:
                // Equality tests are comparisons too: their result is a boolean, not the operand type
                TypeMustMatch(clone->left, clone->right);

                clone->cachedExpressionType = PrimitiveType::Boolean;
                break;

            case BinaryType::Add:
            case BinaryType::Subtract:
                TypeMustMatch(clone->left, clone->right);

                clone->cachedExpressionType = leftExprType;
                break;

            case BinaryType::Multiply:
            case BinaryType::Divide:
            {
                switch (leftType)
                {
                    case PrimitiveType::Float32:
                    case PrimitiveType::Int32:
                    case PrimitiveType::UInt32:
                    {
                        if (IsMatrixType(rightExprType))
                        {
                            TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
                            clone->cachedExpressionType = rightExprType;
                        }
                        else if (IsPrimitiveType(rightExprType))
                        {
                            TypeMustMatch(leftType, rightExprType);
                            clone->cachedExpressionType = leftExprType;
                        }
                        else if (IsVectorType(rightExprType))
                        {
                            TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
                            clone->cachedExpressionType = rightExprType;
                        }
                        else
                            throw AstError{ "incompatible types" };

                        break;
                    }

                    case PrimitiveType::Boolean:
                        throw AstError{ "this operation is not supported for booleans" };

                    default:
                        throw AstError{ "incompatible types" };
                }

                break;
            }
        }
    }
    else if (IsMatrixType(leftExprType))
    {
        const MatrixType& leftType = std::get<MatrixType>(leftExprType);
        switch (clone->op)
        {
            case BinaryType::CompGe:
            case BinaryType::CompGt:
            case BinaryType::CompLe:
            case BinaryType::CompLt:
            case BinaryType::CompEq:
            case BinaryType::CompNe:
                TypeMustMatch(clone->left, clone->right);
                clone->cachedExpressionType = PrimitiveType::Boolean;
                break;

            case BinaryType::Add:
            case BinaryType::Subtract:
                TypeMustMatch(clone->left, clone->right);
                clone->cachedExpressionType = leftExprType;
                break;

            case BinaryType::Multiply:
            case BinaryType::Divide:
            {
                if (IsMatrixType(rightExprType))
                {
                    TypeMustMatch(leftExprType, rightExprType);
                    clone->cachedExpressionType = leftExprType; //< FIXME
                }
                else if (IsPrimitiveType(rightExprType))
                {
                    TypeMustMatch(leftType.type, rightExprType);
                    clone->cachedExpressionType = leftExprType;
                }
                else if (IsVectorType(rightExprType))
                {
                    const VectorType& rightType = std::get<VectorType>(rightExprType);
                    TypeMustMatch(leftType.type, rightType.type);

                    if (leftType.columnCount != rightType.componentCount)
                        throw AstError{ "incompatible types" };

                    clone->cachedExpressionType = rightExprType;
                }
                else
                    throw AstError{ "incompatible types" };

                break;
            }
        }
    }
    else if (IsVectorType(leftExprType))
    {
        const VectorType& leftType = std::get<VectorType>(leftExprType);
        switch (clone->op)
        {
            case BinaryType::CompGe:
            case BinaryType::CompGt:
            case BinaryType::CompLe:
            case BinaryType::CompLt:
            case BinaryType::CompEq:
            case BinaryType::CompNe:
                TypeMustMatch(clone->left, clone->right);
                clone->cachedExpressionType = PrimitiveType::Boolean;
                break;

            case BinaryType::Add:
            case BinaryType::Subtract:
                TypeMustMatch(clone->left, clone->right);
                clone->cachedExpressionType = leftExprType;
                break;

            case BinaryType::Multiply:
            case BinaryType::Divide:
            {
                if (IsPrimitiveType(rightExprType))
                {
                    TypeMustMatch(leftType.type, rightExprType);
                    clone->cachedExpressionType = leftExprType;
                }
                else if (IsVectorType(rightExprType))
                {
                    TypeMustMatch(leftType, rightExprType);
                    clone->cachedExpressionType = rightExprType;
                }
                else
                    throw AstError{ "incompatible types" };

                break;
            }
        }
    }

    return clone;
}
// Sanitizes a cast: the sum of the components provided by the source expressions must equal
// the component count of the target type. Source expressions must be primitives or vectors;
// the expression list is read up to its first null entry.
ExpressionPtr SanitizeVisitor::Clone(CastExpression& node)
{
    auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));

    // Number of components carried by a type: the vector size, or 1 for a primitive
    auto CountComponents = [](const ExpressionType& type) -> std::size_t
    {
        if (!IsVectorType(type))
        {
            assert(IsPrimitiveType(type));
            return 1;
        }

        return std::get<VectorType>(type).componentCount;
    };

    std::size_t providedCount = 0;
    const std::size_t expectedCount = CountComponents(clone->targetType);

    for (auto& sourceExpr : clone->expressions)
    {
        if (!sourceExpr)
            break;

        const ExpressionType& sourceType = GetExpressionType(*sourceExpr);
        const bool isSupported = IsPrimitiveType(sourceType) || IsVectorType(sourceType);
        if (!isSupported)
            throw AstError{ "incompatible type" };

        providedCount += CountComponents(sourceType);
    }

    if (providedCount != expectedCount)
        throw AstError{ "component count doesn't match required component count" };

    clone->targetType = ResolveType(clone->targetType);
    clone->cachedExpressionType = clone->targetType;

    return clone;
}
// Sanitizes a conditional expression: both paths are required and must have the same type,
// which becomes the type of the expression.
ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node)
{
    MandatoryExpr(node.truePath);
    MandatoryExpr(node.falsePath);

    auto condClone = static_unique_pointer_cast<ConditionalExpression>(AstCloner::Clone(node));

    const ExpressionType& trueType = GetExpressionType(*condClone->truePath);
    const ExpressionType& falseType = GetExpressionType(*condClone->falsePath);
    if (trueType != falseType)
        throw AstError{ "true path type must match false path type" };

    condClone->cachedExpressionType = trueType;

    return condClone;
}
// Clones a constant and caches its type, as reported by GetExpressionType for the constant's value.
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
{
    auto constantClone = static_unique_pointer_cast<ConstantExpression>(AstCloner::Clone(node));

    const auto& constantValue = constantClone->value;
    constantClone->cachedExpressionType = GetExpressionType(constantValue);

    return constantClone;
}
// Replaces an identifier expression by a VariableExpression referring to the registered variable.
// Throws AstError when the identifier is unknown or does not designate a variable.
ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node)
{
    assert(m_context);

    const Identifier* found = FindIdentifier(node.identifier);
    if (found == nullptr)
        throw AstError{ "unknown identifier " + node.identifier };

    const Variable* variable = std::get_if<Variable>(&found->value);
    if (variable == nullptr)
        throw AstError{ "expected variable identifier" };

    // Replace IdentifierExpression by VariableExpression
    auto varExpr = std::make_unique<VariableExpression>();
    varExpr->variableId = variable->varIndex;
    varExpr->cachedExpressionType = m_variables[variable->varIndex];

    return varExpr;
}
// Sanitizes an intrinsic call: validates its parameters, then caches its return type.
// - CrossProduct: two parameters of identical type, which must be a 3-component Float32 vector; returns that type.
// - DotProduct:   two parameters of identical vector type; returns the vector base type.
// - Length:       one vector parameter; returns the vector base type.
// - SampleTexture: a sampler and a vector; returns a 4-component vector of the sampled type.
// Throws AstError when a parameter is missing or has an unexpected type.
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
{
    auto clone = static_unique_pointer_cast<IntrinsicExpression>(AstCloner::Clone(node));

    // Checks the parameter count, then that every parameter is present
    auto RequireParameters = [&](std::size_t expectedCount, const char* countError)
    {
        if (clone->parameters.size() != expectedCount)
            throw AstError{ countError };

        for (auto& param : clone->parameters)
            MandatoryExpr(param);
    };

    switch (clone->intrinsic)
    {
        case IntrinsicType::CrossProduct:
        {
            RequireParameters(2, "Expected two parameters");

            const ExpressionType& firstType = GetExpressionType(*clone->parameters[0]);
            if (firstType != GetExpressionType(*clone->parameters[1]))
                throw AstError{ "All type must match" };

            if (firstType != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
                throw AstError{ "CrossProduct only works with vec3<f32> expressions" };

            clone->cachedExpressionType = firstType;
            break;
        }

        case IntrinsicType::DotProduct:
        {
            RequireParameters(2, "Expected two parameters");

            const ExpressionType& firstType = GetExpressionType(*clone->parameters[0]);
            if (firstType != GetExpressionType(*clone->parameters[1]))
                throw AstError{ "All type must match" };

            if (!IsVectorType(firstType))
                throw AstError{ "DotProduct expects vector types" };

            clone->cachedExpressionType = std::get<VectorType>(firstType).type;
            break;
        }

        case IntrinsicType::Length:
        {
            RequireParameters(1, "Expected only one parameters");

            const ExpressionType& paramType = GetExpressionType(*clone->parameters[0]);
            if (!IsVectorType(paramType))
                throw AstError{ "Expected a vector" };

            clone->cachedExpressionType = std::get<VectorType>(paramType).type;
            break;
        }

        case IntrinsicType::SampleTexture:
        {
            RequireParameters(2, "Expected two parameters");

            const ExpressionType& samplerType = GetExpressionType(*clone->parameters[0]);
            if (!IsSamplerType(samplerType))
                throw AstError{ "First parameter must be a sampler" };

            if (!IsVectorType(GetExpressionType(*clone->parameters[1])))
                throw AstError{ "Second parameter must be a vector" };

            clone->cachedExpressionType = VectorType{ 4, std::get<SamplerType>(samplerType).sampledType };
            break;
        }
    }

    return clone;
}
// Turns an option selection by name into a ConditionalExpression bound to the option index.
// Both paths are required and must share the same type.
// Throws AstError when the option is unknown or the name does not designate an option.
ExpressionPtr SanitizeVisitor::Clone(SelectOptionExpression& node)
{
    MandatoryExpr(node.truePath);
    MandatoryExpr(node.falsePath);

    auto conditional = std::make_unique<ConditionalExpression>();
    conditional->truePath = CloneExpression(node.truePath);
    conditional->falsePath = CloneExpression(node.falsePath);

    const Identifier* found = FindIdentifier(node.optionName);
    if (found == nullptr)
        throw AstError{ "unknown option " + node.optionName };

    const Option* option = std::get_if<Option>(&found->value);
    if (option == nullptr)
        throw AstError{ "expected option identifier" };

    conditional->optionIndex = option->optionIndex;

    const ExpressionType& trueType = GetExpressionType(*conditional->truePath);
    const ExpressionType& falseType = GetExpressionType(*conditional->falsePath);
    if (trueType != falseType)
        throw AstError{ "true path type must match false path type" };

    conditional->cachedExpressionType = trueType;

    return conditional;
}
// Sanitizes a swizzle: at most four components, applied to a primitive or a vector.
// The result is the base type for a single component, a vector of that base type otherwise.
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
{
    if (node.componentCount > 4)
        throw AstError{ "Cannot swizzle more than four elements" };

    MandatoryExpr(node.expression);

    auto swizzleClone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));

    const ExpressionType& sourceType = GetExpressionType(*swizzleClone->expression);

    PrimitiveType componentType;
    if (IsPrimitiveType(sourceType))
        componentType = std::get<PrimitiveType>(sourceType);
    else if (IsVectorType(sourceType))
        componentType = std::get<VectorType>(sourceType).type;
    else
        throw AstError{ "Cannot swizzle this type" };

    if (swizzleClone->componentCount <= 1)
        swizzleClone->cachedExpressionType = componentType;
    else
        swizzleClone->cachedExpressionType = VectorType{ swizzleClone->componentCount, componentType };

    return swizzleClone;
}
// Sanitizes a unary expression. The operand must be a primitive:
// logical not requires a boolean, unary plus/minus require Float32, Int32 or UInt32.
// The expression takes the type of its operand.
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
{
    auto unaryClone = static_unique_pointer_cast<UnaryExpression>(AstCloner::Clone(node));

    const ExpressionType& operandExprType = GetExpressionType(MandatoryExpr(unaryClone->expression));
    if (!IsPrimitiveType(operandExprType))
        throw AstError{ "unary expression operand type does not support unary operation" };

    const PrimitiveType operandType = std::get<PrimitiveType>(operandExprType);

    if (node.op == UnaryType::LogicalNot)
    {
        if (operandType != PrimitiveType::Boolean)
            throw AstError{ "logical not is only supported on booleans" };
    }
    else if (node.op == UnaryType::Minus || node.op == UnaryType::Plus)
    {
        const bool isNumeric = operandType == PrimitiveType::Float32 ||
                               operandType == PrimitiveType::Int32 ||
                               operandType == PrimitiveType::UInt32;
        if (!isNumeric)
            throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" };
    }

    unaryClone->cachedExpressionType = operandType;

    return unaryClone;
}
// Sanitizes a branch: each condition must be a boolean expression, and each
// condition/statement pair (as well as the optional else statement) is cloned in its own scope.
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
{
    auto branchClone = std::make_unique<BranchStatement>();
    branchClone->condStatements.reserve(node.condStatements.size());

    for (auto& sourceCond : node.condStatements)
    {
        PushScope();

        auto& clonedCond = branchClone->condStatements.emplace_back();
        clonedCond.condition = CloneExpression(MandatoryExpr(sourceCond.condition));

        const ExpressionType& conditionType = GetExpressionType(*clonedCond.condition);
        const bool isBoolean = IsPrimitiveType(conditionType) && std::get<PrimitiveType>(conditionType) == PrimitiveType::Boolean;
        if (!isBoolean)
            throw AstError{ "branch expressions must resolve to boolean type" };

        clonedCond.statement = CloneStatement(MandatoryStatement(sourceCond.statement));

        PopScope();
    }

    if (node.elseStatement)
    {
        PushScope();
        branchClone->elseStatement = CloneStatement(node.elseStatement);
        PopScope();
    }

    return branchClone;
}
// Clones a conditional statement inside its own scope; the inner statement is required.
StatementPtr SanitizeVisitor::Clone(ConditionalStatement& node)
{
    MandatoryStatement(node.statement);

    PushScope();
    auto conditionalClone = static_unique_pointer_cast<ConditionalStatement>(AstCloner::Clone(node));
    PopScope();

    return conditionalClone;
}
// Sanitizes an external block.
// First pass: rejects binding indexes and variable names already used by a previous external variable.
// Second pass (on the clone): sanitizes names, resolves types and registers each variable;
// only uniform and sampler types are accepted.
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
{
    assert(m_context);

    auto& takenBindings = m_context->usedBindingIndexes;
    auto& knownNames = m_context->declaredExternalVar;

    for (const auto& extVar : node.externalVars)
    {
        if (extVar.bindingIndex)
        {
            const unsigned int bindingIndex = extVar.bindingIndex.value();

            // insert() reports through .second whether the value was newly added
            if (!takenBindings.insert(bindingIndex).second)
                throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
        }

        if (!knownNames.insert(extVar.name).second)
            throw AstError{ "External variable " + extVar.name + " is already declared" };
    }

    auto externalClone = static_unique_pointer_cast<DeclareExternalStatement>(AstCloner::Clone(node));
    for (auto& extVar : externalClone->externalVars)
    {
        SanitizeIdentifier(extVar.name);

        extVar.type = ResolveType(extVar.type);

        // A uniform is registered under the struct type it contains, a sampler under its own type
        ExpressionType registeredType;
        if (IsUniformType(extVar.type))
            registeredType = std::get<StructType>(std::get<UniformType>(extVar.type).containedType);
        else if (IsSamplerType(extVar.type))
            registeredType = extVar.type;
        else
            throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };

        std::size_t registeredIndex = RegisterVariable(extVar.name, std::move(registeredType));
        if (!externalClone->varIndex)
            externalClone->varIndex = registeredIndex; //< First external variable index is node variable index
    }

    return externalClone;
}
// Sanitizes a function declaration.
// - An entry function may be declared once per stage and takes at most one parameter.
// - Parameters are registered as variables in a scope of their own, in which the body is cloned.
// - When the function is tied to an option, the clone is wrapped in a conditional statement on that option.
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
{
    if (node.entryStage)
    {
        DeclareFunctionStatement*& entrySlot = m_context->entryFunctions[UnderlyingCast(*node.entryStage)];
        if (entrySlot)
            throw AstError{ "the same entry type has been defined multiple times" };

        entrySlot = &node;

        if (node.parameters.size() > 1)
            throw AstError{ "entry functions can either take one struct parameter or no parameter" };
    }

    auto funcClone = std::make_unique<DeclareFunctionStatement>();
    funcClone->entryStage = node.entryStage;
    funcClone->name = node.name;
    funcClone->funcIndex = m_nextFuncIndex++;
    funcClone->optionName = node.optionName;
    funcClone->parameters = node.parameters;
    funcClone->returnType = ResolveType(node.returnType);

    SanitizeIdentifier(funcClone->name);

    PushScope();
    {
        for (auto& param : funcClone->parameters)
        {
            param.type = ResolveType(param.type);

            std::size_t paramVarIndex = RegisterVariable(param.name, param.type);
            if (!funcClone->varIndex)
                funcClone->varIndex = paramVarIndex; //< First parameter variable index is node variable index

            SanitizeIdentifier(param.name);
        }

        funcClone->statements.reserve(node.statements.size());
        for (auto& bodyStatement : node.statements)
            funcClone->statements.push_back(CloneStatement(MandatoryStatement(bodyStatement)));
    }
    PopScope();

    if (funcClone->optionName.empty())
        return funcClone;

    // The function only exists when its option is enabled
    const Identifier* found = FindIdentifier(node.optionName);
    if (found == nullptr)
        throw AstError{ "unknown option " + node.optionName };

    const Option* option = std::get_if<Option>(&found->value);
    if (option == nullptr)
        throw AstError{ "expected option identifier" };

    std::size_t optionIndex = option->optionIndex;

    return ShaderBuilder::ConditionalStatement(optionIndex, std::move(funcClone));
}
// Sanitizes an option declaration: resolves its type, checks the initial value type against it
// and registers the option. The declaration is replaced by a no-op when options ask for its removal.
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
{
    auto optionClone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
    optionClone->optType = ResolveType(optionClone->optType);

    if (optionClone->initialValue)
    {
        if (optionClone->optType != GetExpressionType(*optionClone->initialValue))
            throw AstError{ "option " + optionClone->optName + " initial expression must be of the same type than the option" };
    }

    optionClone->optIndex = RegisterOption(optionClone->optName, optionClone->optType);

    if (!m_context->options.removeOptionDeclaration)
        return optionClone;

    return ShaderBuilder::NoOp();
}
// Sanitizes a struct declaration: member names must be unique, member types are resolved
// and the struct is registered under its original name before that name is sanitized.
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
{
    std::unordered_set<std::string> seenMembers;
    for (const auto& member : node.description.members)
    {
        // insert() reports through .second whether the name was newly added
        if (!seenMembers.insert(member.name).second)
            throw AstError{ "struct member " + member.name + " found multiple time" };
    }

    auto structClone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));

    auto& description = structClone->description;
    for (auto& member : description.members)
        member.type = ResolveType(member.type);

    structClone->structIndex = RegisterStruct(description.name, description);

    SanitizeIdentifier(description.name);

    return structClone;
}
// Sanitizes a variable declaration. A variable without an explicit type takes the type of its
// initial expression (which is then required). The variable is registered before its name is sanitized.
StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
{
    auto varClone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));

    if (!IsNoType(varClone->varType))
        varClone->varType = ResolveType(varClone->varType);
    else
    {
        // No explicit type: deduce it from the initial value
        if (!varClone->initialExpression)
            throw AstError{ "variable must either have a type or an initial value" };

        varClone->varType = ResolveType(GetExpressionType(*varClone->initialExpression));
    }

    varClone->varIndex = RegisterVariable(varClone->varName, varClone->varType);

    SanitizeIdentifier(varClone->varName);

    return varClone;
}
// Clones an expression statement after checking that it actually holds an expression.
StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node)
{
    // Only called for its check; the returned reference is not needed
    static_cast<void>(MandatoryExpr(node.expression));

    StatementPtr statementClone = AstCloner::Clone(node);
    return statementClone;
}
// Clones a statement list inside its own scope; every statement of the list must be non-null.
StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
{
    for (auto& childStatement : node.statements)
        MandatoryStatement(childStatement);

    PushScope();
    auto multiClone = static_unique_pointer_cast<MultiStatement>(AstCloner::Clone(node));
    PopScope();

    return multiClone;
}
// Returns the expression held by node, or throws AstError when node is null.
Expression& SanitizeVisitor::MandatoryExpr(ExpressionPtr& node)
{
    if (node)
        return *node;

    throw AstError{ "Invalid expression" };
}
// Returns the statement held by node, or throws AstError when node is null.
Statement& SanitizeVisitor::MandatoryStatement(StatementPtr& node)
{
    if (node)
        return *node;

    throw AstError{ "Invalid statement" };
}
// Opens a scope by recording how many identifiers are currently visible.
void SanitizeVisitor::PushScope()
{
    const auto visibleIdentifierCount = m_identifiersInScope.size();
    m_scopeSizes.push_back(visibleIdentifierCount);
}
// Closes the innermost scope: the identifier list is cut back to the size recorded by the matching PushScope().
void SanitizeVisitor::PopScope()
{
    assert(!m_scopeSizes.empty());

    const auto restoredCount = m_scopeSizes.back();
    m_scopeSizes.pop_back();

    m_identifiersInScope.resize(restoredCount);
}
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType)
{
return std::visit([&](auto&& arg) -> std::size_t
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, IdentifierType> || std::is_same_v<T, StructType> || std::is_same_v<T, UniformType>)
return ResolveStruct(arg);
else if constexpr (std::is_same_v<T, NoType> ||
std::is_same_v<T, PrimitiveType> ||
std::is_same_v<T, MatrixType> ||
std::is_same_v<T, SamplerType> ||
std::is_same_v<T, VectorType>)
{
throw AstError{ "expression is not a structure" };
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, exprType);
}
std::size_t SanitizeVisitor::ResolveStruct(const IdentifierType& identifierType)
{
const Identifier* identifier = FindIdentifier(identifierType.name);
if (!identifier)
throw AstError{ "unknown identifier " + identifierType.name };
if (!std::holds_alternative<Struct>(identifier->value))
throw AstError{ identifierType.name + " is not a struct" };
return std::get<Struct>(identifier->value).structIndex;
}
// A struct type already carries its struct index; return it as is.
std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType)
{
    const std::size_t resolvedIndex = structType.structIndex;
    return resolvedIndex;
}
// Returns the index of the struct contained in a uniform type, whether it is
// still referenced by name or already resolved to a struct type.
std::size_t SanitizeVisitor::ResolveStruct(const UniformType& uniformType)
{
    return std::visit([&](auto&& contained) -> std::size_t
    {
        using Contained = std::decay_t<decltype(contained)>;

        constexpr bool isHandled = std::is_same_v<Contained, IdentifierType> || std::is_same_v<Contained, StructType>;
        if constexpr (isHandled)
            return ResolveStruct(contained);
        else
            static_assert(AlwaysFalse<Contained>::value, "non-exhaustive visitor");
    }, uniformType.containedType);
}
// Returns exprType with type names replaced by the struct types they designate.
// - Identifier types are looked up and turned into a StructType (AstError if unknown or not a struct).
// - Uniform types have their contained type resolved the same way.
// - Every other type is returned unchanged.
ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType)
{
    return std::visit([&](auto&& alternative) -> ExpressionType
    {
        using Alt = std::decay_t<decltype(alternative)>;

        constexpr bool needsNoResolution =
            std::is_same_v<Alt, NoType> ||
            std::is_same_v<Alt, PrimitiveType> ||
            std::is_same_v<Alt, MatrixType> ||
            std::is_same_v<Alt, SamplerType> ||
            std::is_same_v<Alt, StructType> ||
            std::is_same_v<Alt, VectorType>;

        if constexpr (std::is_same_v<Alt, IdentifierType>)
        {
            const Identifier* found = FindIdentifier(alternative.name);
            if (found == nullptr)
                throw AstError{ "unknown identifier " + alternative.name };

            const Struct* structInfo = std::get_if<Struct>(&found->value);
            if (structInfo == nullptr)
                throw AstError{ "expected type identifier" };

            return StructType{ structInfo->structIndex };
        }
        else if constexpr (std::is_same_v<Alt, UniformType>)
        {
            return std::visit([&](auto&& contained)
            {
                ExpressionType resolvedContained = ResolveType(contained);
                assert(std::holds_alternative<StructType>(resolvedContained));

                return UniformType{ std::get<StructType>(resolvedContained) };
            }, alternative.containedType);
        }
        else if constexpr (needsNoResolution)
            return exprType;
        else
            static_assert(AlwaysFalse<Alt>::value, "non-exhaustive visitor");
    }, exprType);
}
// Renames identifier in place while it is one of the reserved identifiers from the options:
// underscores are appended until FindIdentifier no longer finds the new name, then the reserved check is repeated.
// An identifier that is not reserved is left untouched.
void SanitizeVisitor::SanitizeIdentifier(std::string& identifier)
{
    const auto& reserved = m_context->options.reservedIdentifiers;
    auto IsReserved = [&] { return reserved.find(identifier) != reserved.end(); };

    while (IsReserved())
    {
        // At least one underscore is appended per round
        do
            identifier += "_";
        while (FindIdentifier(identifier) != nullptr);
    }
}
// Checks that two expressions have the same type; throws AstError otherwise.
void SanitizeVisitor::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
{
    const ExpressionType& leftType = GetExpressionType(*left);
    const ExpressionType& rightType = GetExpressionType(*right);

    TypeMustMatch(leftType, rightType);
}
// Throws AstError when the two types differ.
void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right)
{
    const bool typesDiffer = (left != right);
    if (typesDiffer)
        throw AstError{ "Left expression type must match right expression type" };
}
}