|
|
|
|
@@ -3,6 +3,7 @@
|
|
|
|
|
// For conditions of distribution and use, see copyright notice in Config.hpp
|
|
|
|
|
|
|
|
|
|
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
|
|
|
|
#include <Nazara/Core/Algorithm.hpp>
|
|
|
|
|
#include <Nazara/Core/CallOnExit.hpp>
|
|
|
|
|
#include <Nazara/Core/StackArray.hpp>
|
|
|
|
|
#include <Nazara/Shader/ShaderBuilder.hpp>
|
|
|
|
|
@@ -25,7 +26,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
template<typename T, typename U>
|
|
|
|
|
std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
|
|
|
|
|
{
|
|
|
|
|
return std::unique_ptr<T>(static_cast<T*>(ptr.release()));
|
|
|
|
|
return std::unique_ptr<T>(SafeCast<T*>(ptr.release()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -52,6 +53,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
std::vector<ExpressionType> variableTypes;
|
|
|
|
|
std::vector<std::size_t> scopeSizes;
|
|
|
|
|
CurrentFunctionData* currentFunction = nullptr;
|
|
|
|
|
std::vector<StatementPtr>* currentStatementList = nullptr;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Sanitize(Statement& statement, const Options& options, std::string* error)
|
|
|
|
|
@@ -117,16 +119,46 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
UInt32 SanitizeVisitor::ToSwizzleIndex(char c)
|
|
|
|
|
{
|
|
|
|
|
switch (c)
|
|
|
|
|
{
|
|
|
|
|
case 'r':
|
|
|
|
|
case 'x':
|
|
|
|
|
case 's':
|
|
|
|
|
return 0u;
|
|
|
|
|
|
|
|
|
|
case 'g':
|
|
|
|
|
case 'y':
|
|
|
|
|
case 't':
|
|
|
|
|
return 1u;
|
|
|
|
|
|
|
|
|
|
case 'b':
|
|
|
|
|
case 'z':
|
|
|
|
|
case 'p':
|
|
|
|
|
return 2u;
|
|
|
|
|
|
|
|
|
|
case 'a':
|
|
|
|
|
case 'w':
|
|
|
|
|
case 'q':
|
|
|
|
|
return 3u;
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
throw AstError{ "unexpected character '" + std::string(&c, 1) + "' on swizzle " };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.identifiers.empty())
|
|
|
|
|
throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
|
|
|
|
|
throw AstError{ "accessIdentifierExpression must have at least one identifier" };
|
|
|
|
|
|
|
|
|
|
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr));
|
|
|
|
|
for (std::size_t i = 0; i < node.identifiers.size(); ++i)
|
|
|
|
|
for (const std::string& identifier : node.identifiers)
|
|
|
|
|
{
|
|
|
|
|
const std::string& identifier = node.identifiers[i];
|
|
|
|
|
if (identifier.empty())
|
|
|
|
|
throw AstError{ "empty identifier" };
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*indexedExpr);
|
|
|
|
|
if (IsStructType(exprType))
|
|
|
|
|
@@ -171,62 +203,64 @@ namespace Nz::ShaderAst
|
|
|
|
|
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(fieldIndex));
|
|
|
|
|
accessIndexPtr->cachedExpressionType = ResolveType(fieldPtr->type);
|
|
|
|
|
}
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
else if (IsPrimitiveType(exprType) || IsVectorType(exprType))
|
|
|
|
|
{
|
|
|
|
|
// Swizzle expression
|
|
|
|
|
const VectorType& swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
PrimitiveType baseType;
|
|
|
|
|
std::size_t componentCount;
|
|
|
|
|
|
|
|
|
|
auto swizzle = std::make_unique<SwizzleExpression>();
|
|
|
|
|
swizzle->expression = std::move(indexedExpr);
|
|
|
|
|
|
|
|
|
|
if (node.identifiers.size() - i != 1)
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
|
|
|
|
|
const std::string& swizzleStr = node.identifiers[i];
|
|
|
|
|
if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size())
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
|
|
|
|
|
swizzle->componentCount = swizzleStr.size();
|
|
|
|
|
|
|
|
|
|
if (swizzle->componentCount > 1)
|
|
|
|
|
swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type };
|
|
|
|
|
else
|
|
|
|
|
swizzle->cachedExpressionType = swizzledVec.type;
|
|
|
|
|
|
|
|
|
|
for (std::size_t j = 0; j < swizzle->componentCount; ++j)
|
|
|
|
|
if (IsVectorType(exprType))
|
|
|
|
|
{
|
|
|
|
|
switch (swizzleStr[j])
|
|
|
|
|
{
|
|
|
|
|
case 'r':
|
|
|
|
|
case 'x':
|
|
|
|
|
case 's':
|
|
|
|
|
swizzle->components[j] = 0u;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'g':
|
|
|
|
|
case 'y':
|
|
|
|
|
case 't':
|
|
|
|
|
swizzle->components[j] = 1u;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'b':
|
|
|
|
|
case 'z':
|
|
|
|
|
case 'p':
|
|
|
|
|
swizzle->components[j] = 2u;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 'a':
|
|
|
|
|
case 'w':
|
|
|
|
|
case 'q':
|
|
|
|
|
swizzle->components[j] = 3u;
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
throw AstError{ "unexpected character '" + std::string(swizzleStr) + "' on swizzle " };
|
|
|
|
|
}
|
|
|
|
|
const VectorType& swizzledVec = std::get<VectorType>(exprType);
|
|
|
|
|
baseType = swizzledVec.type;
|
|
|
|
|
componentCount = swizzledVec.componentCount;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
baseType = std::get<PrimitiveType>(exprType);
|
|
|
|
|
componentCount = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
indexedExpr = std::move(swizzle);
|
|
|
|
|
std::size_t swizzleComponentCount = identifier.size();
|
|
|
|
|
if (swizzleComponentCount > 4)
|
|
|
|
|
throw AstError{ "cannot swizzle more than four elements" };
|
|
|
|
|
|
|
|
|
|
if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType))
|
|
|
|
|
{
|
|
|
|
|
for (std::size_t j = 0; j < swizzleComponentCount; ++j)
|
|
|
|
|
{
|
|
|
|
|
if (ToSwizzleIndex(identifier[j]) != 0)
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (swizzleComponentCount == 1)
|
|
|
|
|
continue; //< ignore this swizzle (a.x == a)
|
|
|
|
|
|
|
|
|
|
// Use a Cast expression to replace swizzle
|
|
|
|
|
indexedExpr = CacheResult(std::move(indexedExpr)); //< Since we are going to use a value multiple times, cache it if required
|
|
|
|
|
|
|
|
|
|
auto cast = std::make_unique<CastExpression>();
|
|
|
|
|
cast->targetType = VectorType{ swizzleComponentCount, baseType };
|
|
|
|
|
for (std::size_t j = 0; j < swizzleComponentCount; ++j)
|
|
|
|
|
cast->expressions[j] = CloneExpression(indexedExpr);
|
|
|
|
|
|
|
|
|
|
Validate(*cast);
|
|
|
|
|
|
|
|
|
|
indexedExpr = std::move(cast);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
auto swizzle = std::make_unique<SwizzleExpression>();
|
|
|
|
|
swizzle->expression = std::move(indexedExpr);
|
|
|
|
|
|
|
|
|
|
swizzle->componentCount = swizzleComponentCount;
|
|
|
|
|
for (std::size_t j = 0; j < swizzleComponentCount; ++j)
|
|
|
|
|
swizzle->components[j] = ToSwizzleIndex(identifier[j]);
|
|
|
|
|
|
|
|
|
|
Validate(*swizzle);
|
|
|
|
|
|
|
|
|
|
indexedExpr = std::move(swizzle);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "unexpected type (only struct and vectors can be indexed with identifiers)" }; //< TODO: Add support for arrays
|
|
|
|
|
@@ -252,11 +286,11 @@ namespace Nz::ShaderAst
|
|
|
|
|
MandatoryExpr(node.left);
|
|
|
|
|
MandatoryExpr(node.right);
|
|
|
|
|
|
|
|
|
|
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
|
|
|
|
|
throw AstError{ "Assignation is only possible with a l-value" };
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<AssignExpression>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
if (GetExpressionCategory(*clone->left) != ExpressionCategory::LValue)
|
|
|
|
|
throw AstError{ "Assignation is only possible with a l-value" };
|
|
|
|
|
|
|
|
|
|
std::optional<BinaryType> binaryType;
|
|
|
|
|
switch (clone->op)
|
|
|
|
|
{
|
|
|
|
|
@@ -366,48 +400,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(CastExpression& node)
|
|
|
|
|
{
|
|
|
|
|
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = clone->targetType;
|
|
|
|
|
clone->targetType = ResolveType(clone->targetType);
|
|
|
|
|
|
|
|
|
|
//FIXME: Make proper rules
|
|
|
|
|
if (IsMatrixType(clone->targetType) && clone->expressions.front())
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*clone->expressions.front());
|
|
|
|
|
if (IsMatrixType(exprType) && !clone->expressions[1])
|
|
|
|
|
{
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
|
|
|
|
|
{
|
|
|
|
|
if (IsVectorType(exprType))
|
|
|
|
|
return std::get<VectorType>(exprType).componentCount;
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
assert(IsPrimitiveType(exprType));
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::size_t componentCount = 0;
|
|
|
|
|
std::size_t requiredComponents = GetComponentCount(clone->targetType);
|
|
|
|
|
|
|
|
|
|
for (auto& exprPtr : clone->expressions)
|
|
|
|
|
{
|
|
|
|
|
if (!exprPtr)
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*exprPtr);
|
|
|
|
|
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
|
|
|
|
throw AstError{ "incompatible type" };
|
|
|
|
|
|
|
|
|
|
componentCount += GetComponentCount(exprType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (componentCount != requiredComponents)
|
|
|
|
|
throw AstError{ "component count doesn't match required component count" };
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -495,38 +488,8 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.componentCount > 4)
|
|
|
|
|
throw AstError{ "Cannot swizzle more than four elements" };
|
|
|
|
|
|
|
|
|
|
for (UInt32 swizzleIndex : node.components)
|
|
|
|
|
{
|
|
|
|
|
if (swizzleIndex >= 4)
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MandatoryExpr(node.expression);
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*clone->expression);
|
|
|
|
|
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
|
|
|
|
throw AstError{ "Cannot swizzle this type" };
|
|
|
|
|
|
|
|
|
|
PrimitiveType baseType;
|
|
|
|
|
if (IsPrimitiveType(exprType))
|
|
|
|
|
baseType = std::get<PrimitiveType>(exprType);
|
|
|
|
|
else
|
|
|
|
|
baseType = std::get<VectorType>(exprType).type;
|
|
|
|
|
|
|
|
|
|
if (clone->componentCount > 1)
|
|
|
|
|
{
|
|
|
|
|
clone->cachedExpressionType = VectorType{
|
|
|
|
|
clone->componentCount,
|
|
|
|
|
baseType
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
clone->cachedExpressionType = baseType;
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -534,50 +497,17 @@ namespace Nz::ShaderAst
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node)
|
|
|
|
|
{
|
|
|
|
|
auto clone = static_unique_pointer_cast<UnaryExpression>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(MandatoryExpr(clone->expression));
|
|
|
|
|
|
|
|
|
|
switch (node.op)
|
|
|
|
|
{
|
|
|
|
|
case UnaryType::LogicalNot:
|
|
|
|
|
{
|
|
|
|
|
if (exprType != ExpressionType(PrimitiveType::Boolean))
|
|
|
|
|
throw AstError{ "logical not is only supported on booleans" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case UnaryType::Minus:
|
|
|
|
|
case UnaryType::Plus:
|
|
|
|
|
{
|
|
|
|
|
ShaderAst::PrimitiveType basicType;
|
|
|
|
|
if (IsPrimitiveType(exprType))
|
|
|
|
|
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
|
|
|
|
|
|
|
|
|
|
if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32)
|
|
|
|
|
throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = exprType;
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(VariableExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.variableId >= m_context->variableTypes.size())
|
|
|
|
|
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
|
|
|
|
|
auto clone = static_unique_pointer_cast<VariableExpression>(AstCloner::Clone(node));
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = m_context->variableTypes[node.variableId];
|
|
|
|
|
|
|
|
|
|
return AstCloner::Clone(node);
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
|
|
|
|
|
@@ -778,6 +708,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
m_context->currentFunction = &tempFuncData;
|
|
|
|
|
|
|
|
|
|
std::vector<StatementPtr>* previousList = m_context->currentStatementList;
|
|
|
|
|
m_context->currentStatementList = &clone->statements;
|
|
|
|
|
|
|
|
|
|
PushScope();
|
|
|
|
|
{
|
|
|
|
|
for (auto& parameter : clone->parameters)
|
|
|
|
|
@@ -796,6 +729,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
}
|
|
|
|
|
PopScope();
|
|
|
|
|
|
|
|
|
|
m_context->currentStatementList = previousList;
|
|
|
|
|
m_context->currentFunction = nullptr;
|
|
|
|
|
|
|
|
|
|
if (clone->earlyFragmentTests.HasValue() && clone->earlyFragmentTests.GetResultingValue())
|
|
|
|
|
@@ -897,33 +831,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
throw AstError{ "global variables outside of external blocks are forbidden" };
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));
|
|
|
|
|
if (IsNoType(clone->varType))
|
|
|
|
|
{
|
|
|
|
|
if (!clone->initialExpression)
|
|
|
|
|
throw AstError{ "variable must either have a type or an initial value" };
|
|
|
|
|
|
|
|
|
|
clone->varType = ResolveType(GetExpressionType(*clone->initialExpression));
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
clone->varType = ResolveType(clone->varType);
|
|
|
|
|
|
|
|
|
|
if (m_context->options.makeVariableNameUnique && FindIdentifier(clone->varName) != nullptr)
|
|
|
|
|
{
|
|
|
|
|
// Try to append _X to the variable name until by incrementing X
|
|
|
|
|
unsigned int cloneIndex = 2;
|
|
|
|
|
std::string candidateName;
|
|
|
|
|
do
|
|
|
|
|
{
|
|
|
|
|
candidateName = clone->varName + "_" + std::to_string(cloneIndex++);
|
|
|
|
|
}
|
|
|
|
|
while (FindIdentifier(candidateName) != nullptr);
|
|
|
|
|
|
|
|
|
|
clone->varName = std::move(candidateName);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
clone->varIndex = RegisterVariable(clone->varName, clone->varType);
|
|
|
|
|
|
|
|
|
|
SanitizeIdentifier(clone->varName);
|
|
|
|
|
Validate(*clone);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -947,12 +855,18 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
|
|
|
|
|
{
|
|
|
|
|
for (auto& statement : node.statements)
|
|
|
|
|
MandatoryStatement(statement);
|
|
|
|
|
|
|
|
|
|
PushScope();
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<MultiStatement>(AstCloner::Clone(node));
|
|
|
|
|
auto clone = std::make_unique<MultiStatement>();
|
|
|
|
|
clone->statements.reserve(node.statements.size());
|
|
|
|
|
|
|
|
|
|
std::vector<StatementPtr>* previousList = m_context->currentStatementList;
|
|
|
|
|
m_context->currentStatementList = &clone->statements;
|
|
|
|
|
|
|
|
|
|
for (auto& statement : node.statements)
|
|
|
|
|
clone->statements.push_back(AstCloner::Clone(MandatoryStatement(statement)));
|
|
|
|
|
|
|
|
|
|
m_context->currentStatementList = previousList;
|
|
|
|
|
|
|
|
|
|
PopScope();
|
|
|
|
|
|
|
|
|
|
@@ -1009,6 +923,25 @@ namespace Nz::ShaderAst
|
|
|
|
|
m_context->scopeSizes.pop_back();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
|
|
|
|
|
{
|
|
|
|
|
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doens't need to be cached as well)
|
|
|
|
|
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
|
|
|
|
|
return expression;
|
|
|
|
|
|
|
|
|
|
assert(m_context->currentStatementList);
|
|
|
|
|
|
|
|
|
|
auto variableDeclaration = ShaderBuilder::DeclareVariable("cachedResult", std::move(expression)); //< Validation will prevent name-clash if required
|
|
|
|
|
Validate(*variableDeclaration);
|
|
|
|
|
|
|
|
|
|
auto varExpr = std::make_unique<VariableExpression>();
|
|
|
|
|
varExpr->variableId = *variableDeclaration->varIndex;
|
|
|
|
|
|
|
|
|
|
m_context->currentStatementList->push_back(std::move(variableDeclaration));
|
|
|
|
|
|
|
|
|
|
return varExpr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
|
const T& SanitizeVisitor::ComputeAttributeValue(AttributeValue<T>& attribute)
|
|
|
|
|
{
|
|
|
|
|
@@ -1381,6 +1314,83 @@ namespace Nz::ShaderAst
|
|
|
|
|
node.cachedExpressionType = referenceDeclaration->returnType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CastExpression& node)
|
|
|
|
|
{
|
|
|
|
|
node.cachedExpressionType = node.targetType;
|
|
|
|
|
node.targetType = ResolveType(node.targetType);
|
|
|
|
|
|
|
|
|
|
// Allow casting a matrix to itself (wtf?)
|
|
|
|
|
// FIXME: Make proper rules
|
|
|
|
|
if (IsMatrixType(node.targetType) && node.expressions.front())
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*node.expressions.front());
|
|
|
|
|
if (IsMatrixType(exprType) && !node.expressions[1])
|
|
|
|
|
{
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
|
|
|
|
|
{
|
|
|
|
|
if (IsVectorType(exprType))
|
|
|
|
|
return std::get<VectorType>(exprType).componentCount;
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
assert(IsPrimitiveType(exprType));
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::size_t componentCount = 0;
|
|
|
|
|
std::size_t requiredComponents = GetComponentCount(node.targetType);
|
|
|
|
|
|
|
|
|
|
for (auto& exprPtr : node.expressions)
|
|
|
|
|
{
|
|
|
|
|
if (!exprPtr)
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*exprPtr);
|
|
|
|
|
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
|
|
|
|
throw AstError{ "incompatible type" };
|
|
|
|
|
|
|
|
|
|
componentCount += GetComponentCount(exprType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (componentCount != requiredComponents)
|
|
|
|
|
throw AstError{ "component count doesn't match required component count" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(DeclareVariableStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (IsNoType(node.varType))
|
|
|
|
|
{
|
|
|
|
|
if (!node.initialExpression)
|
|
|
|
|
throw AstError{ "variable must either have a type or an initial value" };
|
|
|
|
|
|
|
|
|
|
node.varType = ResolveType(GetExpressionType(*node.initialExpression));
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
node.varType = ResolveType(node.varType);
|
|
|
|
|
|
|
|
|
|
if (m_context->options.makeVariableNameUnique && FindIdentifier(node.varName) != nullptr)
|
|
|
|
|
{
|
|
|
|
|
// Try to make variable name unique by appending _X to its name (incrementing X until it's unique) to the variable name until by incrementing X
|
|
|
|
|
unsigned int cloneIndex = 2;
|
|
|
|
|
std::string candidateName;
|
|
|
|
|
do
|
|
|
|
|
{
|
|
|
|
|
candidateName = node.varName + "_" + std::to_string(cloneIndex++);
|
|
|
|
|
}
|
|
|
|
|
while (FindIdentifier(candidateName) != nullptr);
|
|
|
|
|
|
|
|
|
|
node.varName = std::move(candidateName);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node.varIndex = RegisterVariable(node.varName, node.varType);
|
|
|
|
|
|
|
|
|
|
SanitizeIdentifier(node.varName);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(IntrinsicExpression& node)
|
|
|
|
|
{
|
|
|
|
|
// Parameter validation
|
|
|
|
|
@@ -1510,6 +1520,90 @@ namespace Nz::ShaderAst
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(SwizzleExpression& node)
|
|
|
|
|
{
|
|
|
|
|
MandatoryExpr(node.expression);
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(*node.expression);
|
|
|
|
|
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
|
|
|
|
|
throw AstError{ "Cannot swizzle this type" };
|
|
|
|
|
|
|
|
|
|
PrimitiveType baseType;
|
|
|
|
|
std::size_t componentCount;
|
|
|
|
|
if (IsPrimitiveType(exprType))
|
|
|
|
|
{
|
|
|
|
|
baseType = std::get<PrimitiveType>(exprType);
|
|
|
|
|
componentCount = 1;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
const VectorType& vecType = std::get<VectorType>(exprType);
|
|
|
|
|
baseType = vecType.type;
|
|
|
|
|
componentCount = vecType.componentCount;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node.componentCount > 4)
|
|
|
|
|
throw AstError{ "cannot swizzle more than four elements" };
|
|
|
|
|
|
|
|
|
|
for (UInt32 swizzleIndex : node.components)
|
|
|
|
|
{
|
|
|
|
|
if (swizzleIndex >= componentCount)
|
|
|
|
|
throw AstError{ "invalid swizzle" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node.componentCount > 1)
|
|
|
|
|
{
|
|
|
|
|
node.cachedExpressionType = VectorType{
|
|
|
|
|
node.componentCount,
|
|
|
|
|
baseType
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
node.cachedExpressionType = baseType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(UnaryExpression& node)
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression));
|
|
|
|
|
|
|
|
|
|
switch (node.op)
|
|
|
|
|
{
|
|
|
|
|
case UnaryType::LogicalNot:
|
|
|
|
|
{
|
|
|
|
|
if (exprType != ExpressionType(PrimitiveType::Boolean))
|
|
|
|
|
throw AstError{ "logical not is only supported on booleans" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case UnaryType::Minus:
|
|
|
|
|
case UnaryType::Plus:
|
|
|
|
|
{
|
|
|
|
|
ShaderAst::PrimitiveType basicType;
|
|
|
|
|
if (IsPrimitiveType(exprType))
|
|
|
|
|
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" };
|
|
|
|
|
|
|
|
|
|
if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32)
|
|
|
|
|
throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" };
|
|
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = exprType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(VariableExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.variableId >= m_context->variableTypes.size())
|
|
|
|
|
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = m_context->variableTypes[node.variableId];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(leftExpr));
|
|
|
|
|
|