Shader/SanitizeVisitor: Fix sanitization of already sanitized AST with holes in indices
this happens when you sanitize an AST that went through a remove unused pass
This commit is contained in:
parent
142f15d538
commit
36aea2ca0c
|
|
@ -110,13 +110,13 @@ namespace Nz::ShaderAst
|
||||||
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
|
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
|
||||||
|
|
||||||
void RegisterBuiltin();
|
void RegisterBuiltin();
|
||||||
std::size_t RegisterConstant(std::string name, ConstantValue value);
|
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterFunction(std::string name, FunctionData funcData);
|
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
|
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterStruct(std::string name, StructDescription* description);
|
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterType(std::string name, ExpressionType expressionType);
|
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterType(std::string name, PartialType partialType);
|
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterVariable(std::string name, ExpressionType type);
|
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
||||||
|
|
||||||
void ResolveFunctions();
|
void ResolveFunctions();
|
||||||
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,47 @@ namespace Nz::ShaderAst
|
||||||
FunctionFlags flags;
|
FunctionFlags flags;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct IdentifierData
|
||||||
|
{
|
||||||
|
Bitset<UInt64> availableIndices;
|
||||||
|
std::unordered_map<std::size_t, T> values;
|
||||||
|
|
||||||
|
template<typename U>
|
||||||
|
std::size_t Register(U&& data, std::optional<std::size_t> index = {})
|
||||||
|
{
|
||||||
|
std::size_t dataIndex;
|
||||||
|
if (index.has_value())
|
||||||
|
dataIndex = *index;
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dataIndex = availableIndices.FindFirst();
|
||||||
|
if (dataIndex == availableIndices.npos)
|
||||||
|
dataIndex = availableIndices.GetSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dataIndex >= availableIndices.GetSize())
|
||||||
|
availableIndices.Resize(dataIndex + 1, true);
|
||||||
|
else if (!availableIndices.Test(dataIndex))
|
||||||
|
throw AstError{ "index " + std::to_string(dataIndex) + " is already used" };
|
||||||
|
|
||||||
|
assert(values.find(dataIndex) == values.end());
|
||||||
|
|
||||||
|
availableIndices.Set(dataIndex, false);
|
||||||
|
values.emplace(dataIndex, std::forward<U>(data));
|
||||||
|
return dataIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
T& Retrieve(std::size_t index)
|
||||||
|
{
|
||||||
|
auto it = values.find(index);
|
||||||
|
if (it == values.end())
|
||||||
|
throw AstError{ "invalid index " + std::to_string(index) };
|
||||||
|
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct PendingFunction
|
struct PendingFunction
|
||||||
{
|
{
|
||||||
DeclareFunctionStatement* cloneNode;
|
DeclareFunctionStatement* cloneNode;
|
||||||
|
|
@ -56,16 +97,16 @@ namespace Nz::ShaderAst
|
||||||
std::size_t nextOptionIndex = 0;
|
std::size_t nextOptionIndex = 0;
|
||||||
Options options;
|
Options options;
|
||||||
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
|
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
|
||||||
|
std::vector<Identifier> identifiersInScope;
|
||||||
|
std::vector<PendingFunction> pendingFunctions;
|
||||||
std::unordered_set<std::string> declaredExternalVar;
|
std::unordered_set<std::string> declaredExternalVar;
|
||||||
std::unordered_set<UInt64> usedBindingIndexes;
|
std::unordered_set<UInt64> usedBindingIndexes;
|
||||||
std::vector<ConstantValue> constantValues;
|
IdentifierData<ConstantValue> constantValues;
|
||||||
std::vector<FunctionData> functions;
|
IdentifierData<FunctionData> functions;
|
||||||
std::vector<Identifier> identifiersInScope;
|
IdentifierData<IntrinsicType> intrinsics;
|
||||||
std::vector<IntrinsicType> intrinsics;
|
IdentifierData<StructDescription*> structs;
|
||||||
std::vector<PendingFunction> pendingFunctions;
|
IdentifierData<std::variant<ExpressionType, PartialType>> types;
|
||||||
std::vector<StructDescription*> structs;
|
IdentifierData<ExpressionType> variableTypes;
|
||||||
std::vector<std::variant<ExpressionType, PartialType>> types;
|
|
||||||
std::vector<ExpressionType> variableTypes;
|
|
||||||
std::vector<Scope> scopes;
|
std::vector<Scope> scopes;
|
||||||
CurrentFunctionData* currentFunction = nullptr;
|
CurrentFunctionData* currentFunction = nullptr;
|
||||||
std::vector<StatementPtr>* currentStatementList = nullptr;
|
std::vector<StatementPtr>* currentStatementList = nullptr;
|
||||||
|
|
@ -185,8 +226,7 @@ namespace Nz::ShaderAst
|
||||||
else if (IsStructType(exprType))
|
else if (IsStructType(exprType))
|
||||||
{
|
{
|
||||||
std::size_t structIndex = ResolveStruct(exprType);
|
std::size_t structIndex = ResolveStruct(exprType);
|
||||||
assert(structIndex < m_context->structs.size());
|
const StructDescription* s = m_context->structs.Retrieve(structIndex);
|
||||||
const StructDescription* s = m_context->structs[structIndex];
|
|
||||||
|
|
||||||
// Retrieve member index (not counting disabled fields)
|
// Retrieve member index (not counting disabled fields)
|
||||||
Int32 fieldIndex = 0;
|
Int32 fieldIndex = 0;
|
||||||
|
|
@ -513,11 +553,8 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
||||||
{
|
{
|
||||||
if (node.constantId >= m_context->constantValues.size())
|
|
||||||
throw AstError{ "invalid constant index " + std::to_string(node.constantId) };
|
|
||||||
|
|
||||||
// Replace by constant value
|
// Replace by constant value
|
||||||
auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]);
|
auto constant = ShaderBuilder::Constant(m_context->constantValues.Retrieve(node.constantId));
|
||||||
constant->cachedExpressionType = GetExpressionType(constant->value);
|
constant->cachedExpressionType = GetExpressionType(constant->value);
|
||||||
|
|
||||||
return constant;
|
return constant;
|
||||||
|
|
@ -552,8 +589,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
case Identifier::Type::Intrinsic:
|
case Identifier::Type::Intrinsic:
|
||||||
{
|
{
|
||||||
assert(identifier->index < m_context->intrinsics.size());
|
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifier->index);
|
||||||
IntrinsicType intrinsicType = m_context->intrinsics[identifier->index];
|
|
||||||
|
|
||||||
auto clone = AstCloner::Clone(node);
|
auto clone = AstCloner::Clone(node);
|
||||||
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
|
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
|
||||||
|
|
@ -581,7 +617,7 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
// Replace IdentifierExpression by VariableExpression
|
// Replace IdentifierExpression by VariableExpression
|
||||||
auto varExpr = std::make_unique<VariableExpression>();
|
auto varExpr = std::make_unique<VariableExpression>();
|
||||||
varExpr->cachedExpressionType = m_context->variableTypes[identifier->index];
|
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifier->index);
|
||||||
varExpr->variableId = identifier->index;
|
varExpr->variableId = identifier->index;
|
||||||
|
|
||||||
return varExpr;
|
return varExpr;
|
||||||
|
|
@ -732,7 +768,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
clone->type = expressionType;
|
clone->type = expressionType;
|
||||||
|
|
||||||
clone->constIndex = RegisterConstant(clone->name, value);
|
clone->constIndex = RegisterConstant(clone->name, value, clone->constIndex);
|
||||||
|
|
||||||
if (m_context->options.removeConstDeclaration)
|
if (m_context->options.removeConstDeclaration)
|
||||||
return ShaderBuilder::NoOp();
|
return ShaderBuilder::NoOp();
|
||||||
|
|
@ -786,7 +822,7 @@ namespace Nz::ShaderAst
|
||||||
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||||
|
|
||||||
extVar.type = std::move(resolvedType);
|
extVar.type = std::move(resolvedType);
|
||||||
extVar.varIndex = RegisterVariable(extVar.name, std::move(varType));
|
extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex);
|
||||||
|
|
||||||
SanitizeIdentifier(extVar.name);
|
SanitizeIdentifier(extVar.name);
|
||||||
}
|
}
|
||||||
|
|
@ -858,7 +894,7 @@ namespace Nz::ShaderAst
|
||||||
FunctionData funcData;
|
FunctionData funcData;
|
||||||
funcData.node = clone.get(); //< update function node
|
funcData.node = clone.get(); //< update function node
|
||||||
|
|
||||||
std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData));
|
std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), node.funcIndex);
|
||||||
clone->funcIndex = funcIndex;
|
clone->funcIndex = funcIndex;
|
||||||
|
|
||||||
SanitizeIdentifier(clone->name);
|
SanitizeIdentifier(clone->name);
|
||||||
|
|
@ -883,9 +919,9 @@ namespace Nz::ShaderAst
|
||||||
std::size_t optionIndex = m_context->nextOptionIndex++;
|
std::size_t optionIndex = m_context->nextOptionIndex++;
|
||||||
|
|
||||||
if (auto optionValueIt = m_context->options.optionValues.find(optionIndex); optionValueIt != m_context->options.optionValues.end())
|
if (auto optionValueIt = m_context->options.optionValues.find(optionIndex); optionValueIt != m_context->options.optionValues.end())
|
||||||
clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second);
|
clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, clone->optIndex);
|
||||||
else if (clone->defaultValue)
|
else if (clone->defaultValue)
|
||||||
clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue));
|
clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), clone->optIndex);
|
||||||
else
|
else
|
||||||
throw AstError{ "missing option " + clone->optName + " value (has no default value)" };
|
throw AstError{ "missing option " + clone->optName + " value (has no default value)" };
|
||||||
|
|
||||||
|
|
@ -931,7 +967,7 @@ namespace Nz::ShaderAst
|
||||||
else if (IsStructType(resolvedType))
|
else if (IsStructType(resolvedType))
|
||||||
{
|
{
|
||||||
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
|
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
|
||||||
const StructDescription* desc = m_context->structs[structIndex];
|
const StructDescription* desc = m_context->structs.Retrieve(structIndex);
|
||||||
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
||||||
throw AstError{ "inner struct layout mismatch" };
|
throw AstError{ "inner struct layout mismatch" };
|
||||||
}
|
}
|
||||||
|
|
@ -940,7 +976,7 @@ namespace Nz::ShaderAst
|
||||||
member.type = std::move(resolvedType);
|
member.type = std::move(resolvedType);
|
||||||
}
|
}
|
||||||
|
|
||||||
clone->structIndex = RegisterStruct(clone->description.name, &clone->description);
|
clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex);
|
||||||
|
|
||||||
SanitizeIdentifier(clone->description.name);
|
SanitizeIdentifier(clone->description.name);
|
||||||
|
|
||||||
|
|
@ -1347,7 +1383,7 @@ namespace Nz::ShaderAst
|
||||||
switch (identifier->type)
|
switch (identifier->type)
|
||||||
{
|
{
|
||||||
case Identifier::Type::Constant:
|
case Identifier::Type::Constant:
|
||||||
return m_context->constantValues[identifier->index];
|
return m_context->constantValues.Retrieve(identifier->index);
|
||||||
|
|
||||||
case Identifier::Type::Struct:
|
case Identifier::Type::Struct:
|
||||||
return StructType{ identifier->index };
|
return StructType{ identifier->index };
|
||||||
|
|
@ -1356,7 +1392,7 @@ namespace Nz::ShaderAst
|
||||||
return std::visit([&](auto&& arg) -> TypeParameter
|
return std::visit([&](auto&& arg) -> TypeParameter
|
||||||
{
|
{
|
||||||
return arg;
|
return arg;
|
||||||
}, m_context->types[identifier->index]);
|
}, m_context->types.Retrieve(identifier->index));
|
||||||
|
|
||||||
case Identifier::Type::Alias:
|
case Identifier::Type::Alias:
|
||||||
throw std::runtime_error("TODO");
|
throw std::runtime_error("TODO");
|
||||||
|
|
@ -1469,8 +1505,7 @@ namespace Nz::ShaderAst
|
||||||
AstConstantPropagationVisitor::Options optimizerOptions;
|
AstConstantPropagationVisitor::Options optimizerOptions;
|
||||||
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
|
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
|
||||||
{
|
{
|
||||||
assert(constantId < m_context->constantValues.size());
|
return m_context->constantValues.Retrieve(constantId);
|
||||||
return m_context->constantValues[constantId];
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Run optimizer on constant value to hopefully retrieve a single constant value
|
// Run optimizer on constant value to hopefully retrieve a single constant value
|
||||||
|
|
@ -1479,8 +1514,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
||||||
{
|
{
|
||||||
assert(funcIndex < m_context->functions.size());
|
auto& funcData = m_context->functions.Retrieve(funcIndex);
|
||||||
auto& funcData = m_context->functions[funcIndex];
|
|
||||||
funcData.flags |= flags;
|
funcData.flags |= flags;
|
||||||
|
|
||||||
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
|
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
|
||||||
|
|
@ -1651,13 +1685,12 @@ namespace Nz::ShaderAst
|
||||||
RegisterIntrinsic("reflect", IntrinsicType::Reflect);
|
RegisterIntrinsic("reflect", IntrinsicType::Reflect);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value)
|
std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (FindIdentifier(name))
|
if (FindIdentifier(name))
|
||||||
throw AstError{ name + " is already used" };
|
throw AstError{ name + " is already used" };
|
||||||
|
|
||||||
std::size_t constantIndex = m_context->constantValues.size();
|
std::size_t constantIndex = m_context->constantValues.Register(std::move(value), index);
|
||||||
m_context->constantValues.emplace_back(std::move(value));
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1668,7 +1701,7 @@ namespace Nz::ShaderAst
|
||||||
return constantIndex;
|
return constantIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData)
|
std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (auto* identifier = FindIdentifier(name))
|
if (auto* identifier = FindIdentifier(name))
|
||||||
{
|
{
|
||||||
|
|
@ -1677,7 +1710,7 @@ namespace Nz::ShaderAst
|
||||||
// Functions cannot be declared twice, except for entry ones if their stages are different
|
// Functions cannot be declared twice, except for entry ones if their stages are different
|
||||||
if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function)
|
if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function)
|
||||||
{
|
{
|
||||||
auto& otherFunction = m_context->functions[identifier->index];
|
auto& otherFunction = m_context->functions.Retrieve(identifier->index);
|
||||||
if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue())
|
if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue())
|
||||||
duplicate = false;
|
duplicate = false;
|
||||||
}
|
}
|
||||||
|
|
@ -1686,9 +1719,7 @@ namespace Nz::ShaderAst
|
||||||
throw AstError{ funcData.node->name + " is already used" };
|
throw AstError{ funcData.node->name + " is already used" };
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t functionIndex = m_context->functions.size();
|
std::size_t functionIndex = m_context->functions.Register(std::move(funcData), index);
|
||||||
|
|
||||||
m_context->functions.emplace_back(std::move(funcData));
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1699,13 +1730,12 @@ namespace Nz::ShaderAst
|
||||||
return functionIndex;
|
return functionIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
|
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (FindIdentifier(name))
|
if (FindIdentifier(name))
|
||||||
throw AstError{ name + " is already used" };
|
throw AstError{ name + " is already used" };
|
||||||
|
|
||||||
std::size_t intrinsicIndex = m_context->intrinsics.size();
|
std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type), index);
|
||||||
m_context->intrinsics.push_back(type);
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1716,13 +1746,12 @@ namespace Nz::ShaderAst
|
||||||
return intrinsicIndex;
|
return intrinsicIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description)
|
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (FindIdentifier(name))
|
if (FindIdentifier(name))
|
||||||
throw AstError{ name + " is already used" };
|
throw AstError{ name + " is already used" };
|
||||||
|
|
||||||
std::size_t structIndex = m_context->structs.size();
|
std::size_t structIndex = m_context->structs.Register(description, index);
|
||||||
m_context->structs.emplace_back(description);
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1733,13 +1762,12 @@ namespace Nz::ShaderAst
|
||||||
return structIndex;
|
return structIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType)
|
std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (FindIdentifier(name))
|
if (FindIdentifier(name))
|
||||||
throw AstError{ name + " is already used" };
|
throw AstError{ name + " is already used" };
|
||||||
|
|
||||||
std::size_t typeIndex = m_context->types.size();
|
std::size_t typeIndex = m_context->types.Register(std::move(expressionType), index);
|
||||||
m_context->types.emplace_back(std::move(expressionType));
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1750,13 +1778,12 @@ namespace Nz::ShaderAst
|
||||||
return typeIndex;
|
return typeIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType)
|
std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (FindIdentifier(name))
|
if (FindIdentifier(name))
|
||||||
throw AstError{ name + " is already used" };
|
throw AstError{ name + " is already used" };
|
||||||
|
|
||||||
std::size_t typeIndex = m_context->types.size();
|
std::size_t typeIndex = m_context->types.Register(std::move(partialType), index);
|
||||||
m_context->types.emplace_back(std::move(partialType));
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1767,7 +1794,7 @@ namespace Nz::ShaderAst
|
||||||
return typeIndex;
|
return typeIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type)
|
std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index)
|
||||||
{
|
{
|
||||||
if (auto* identifier = FindIdentifier(name))
|
if (auto* identifier = FindIdentifier(name))
|
||||||
{
|
{
|
||||||
|
|
@ -1776,8 +1803,7 @@ namespace Nz::ShaderAst
|
||||||
throw AstError{ name + " is already used" };
|
throw AstError{ name + " is already used" };
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t varIndex = m_context->variableTypes.size();
|
std::size_t varIndex = m_context->variableTypes.Register(std::move(type), index);
|
||||||
m_context->variableTypes.emplace_back(std::move(type));
|
|
||||||
|
|
||||||
m_context->identifiersInScope.push_back({
|
m_context->identifiersInScope.push_back({
|
||||||
std::move(name),
|
std::move(name),
|
||||||
|
|
@ -1795,13 +1821,16 @@ namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
PushScope();
|
PushScope();
|
||||||
|
|
||||||
|
std::optional<std::size_t> varIndex = pendingFunc.cloneNode->varIndex;
|
||||||
for (auto& parameter : pendingFunc.cloneNode->parameters)
|
for (auto& parameter : pendingFunc.cloneNode->parameters)
|
||||||
{
|
{
|
||||||
std::size_t varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue());
|
std::size_t index = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), varIndex);
|
||||||
if (!pendingFunc.cloneNode->varIndex)
|
if (!pendingFunc.cloneNode->varIndex)
|
||||||
pendingFunc.cloneNode->varIndex = varIndex; //< First parameter variable index is node variable index
|
pendingFunc.cloneNode->varIndex = index; //< First parameter variable index is node variable index
|
||||||
|
|
||||||
SanitizeIdentifier(parameter.name);
|
SanitizeIdentifier(parameter.name);
|
||||||
|
if (varIndex)
|
||||||
|
(*varIndex)++;
|
||||||
}
|
}
|
||||||
|
|
||||||
Context::CurrentFunctionData tempFuncData;
|
Context::CurrentFunctionData tempFuncData;
|
||||||
|
|
@ -1823,8 +1852,7 @@ namespace Nz::ShaderAst
|
||||||
std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex;
|
std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex;
|
||||||
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
|
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
|
||||||
{
|
{
|
||||||
assert(i < m_context->functions.size());
|
auto& targetFunc = m_context->functions.Retrieve(i);
|
||||||
auto& targetFunc = m_context->functions[i];
|
|
||||||
targetFunc.calledByFunctions.UnboundedSet(funcIndex);
|
targetFunc.calledByFunctions.UnboundedSet(funcIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1832,15 +1860,13 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
|
|
||||||
Bitset<> seen;
|
Bitset<> seen;
|
||||||
for (std::size_t funcIndex = 0; funcIndex < m_context->functions.size(); ++funcIndex)
|
for (const auto& [funcIndex, funcData] : m_context->functions.values)
|
||||||
{
|
{
|
||||||
auto& funcData = m_context->functions[funcIndex];
|
|
||||||
|
|
||||||
PropagateFunctionFlags(funcIndex, funcData.flags, seen);
|
PropagateFunctionFlags(funcIndex, funcData.flags, seen);
|
||||||
seen.Clear();
|
seen.Clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const FunctionData& funcData : m_context->functions)
|
for (const auto& [funcIndex, funcData] : m_context->functions.values)
|
||||||
{
|
{
|
||||||
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
||||||
throw AstError{ "discard can only be used in the fragment stage" };
|
throw AstError{ "discard can only be used in the fragment stage" };
|
||||||
|
|
@ -1919,7 +1945,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
||||||
|
|
||||||
const auto& type = m_context->types[typeIndex];
|
const auto& type = m_context->types.Retrieve(typeIndex);
|
||||||
if (std::holds_alternative<PartialType>(type))
|
if (std::holds_alternative<PartialType>(type))
|
||||||
throw AstError{ "full type expected" };
|
throw AstError{ "full type expected" };
|
||||||
|
|
||||||
|
|
@ -1991,7 +2017,7 @@ namespace Nz::ShaderAst
|
||||||
if (IsTypeExpression(exprType))
|
if (IsTypeExpression(exprType))
|
||||||
{
|
{
|
||||||
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
|
||||||
const auto& type = m_context->types[typeIndex];
|
const auto& type = m_context->types.Retrieve(typeIndex);
|
||||||
|
|
||||||
if (!std::holds_alternative<PartialType>(type))
|
if (!std::holds_alternative<PartialType>(type))
|
||||||
throw std::runtime_error("only partial types can be specialized");
|
throw std::runtime_error("only partial types can be specialized");
|
||||||
|
|
@ -2084,8 +2110,7 @@ namespace Nz::ShaderAst
|
||||||
Int32 index = std::get<Int32>(constantExpr.value);
|
Int32 index = std::get<Int32>(constantExpr.value);
|
||||||
|
|
||||||
std::size_t structIndex = ResolveStruct(exprType);
|
std::size_t structIndex = ResolveStruct(exprType);
|
||||||
assert(structIndex < m_context->structs.size());
|
const StructDescription* s = m_context->structs.Retrieve(structIndex);
|
||||||
const StructDescription* s = m_context->structs[structIndex];
|
|
||||||
|
|
||||||
exprType = ResolveType(s->members[index].type);
|
exprType = ResolveType(s->members[index].type);
|
||||||
}
|
}
|
||||||
|
|
@ -2159,8 +2184,7 @@ namespace Nz::ShaderAst
|
||||||
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
||||||
|
|
||||||
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
||||||
assert(targetFuncIndex < m_context->functions.size());
|
auto& funcData = m_context->functions.Retrieve(targetFuncIndex);
|
||||||
auto& funcData = m_context->functions[targetFuncIndex];
|
|
||||||
|
|
||||||
const DeclareFunctionStatement* referenceDeclaration = funcData.node;
|
const DeclareFunctionStatement* referenceDeclaration = funcData.node;
|
||||||
|
|
||||||
|
|
@ -2271,7 +2295,7 @@ namespace Nz::ShaderAst
|
||||||
TypeMustMatch(resolvedType, GetExpressionType(*node.initialExpression));
|
TypeMustMatch(resolvedType, GetExpressionType(*node.initialExpression));
|
||||||
}
|
}
|
||||||
|
|
||||||
node.varIndex = RegisterVariable(node.varName, resolvedType);
|
node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex);
|
||||||
node.varType = std::move(resolvedType);
|
node.varType = std::move(resolvedType);
|
||||||
|
|
||||||
if (m_context->options.makeVariableNameUnique)
|
if (m_context->options.makeVariableNameUnique)
|
||||||
|
|
@ -2523,10 +2547,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void SanitizeVisitor::Validate(VariableExpression& node)
|
void SanitizeVisitor::Validate(VariableExpression& node)
|
||||||
{
|
{
|
||||||
if (node.variableId >= m_context->variableTypes.size())
|
node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId);
|
||||||
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
|
|
||||||
|
|
||||||
node.cachedExpressionType = m_context->variableTypes[node.variableId];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)
|
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue