Shader: Add support for partial sanitization
This commit is contained in:
@@ -46,6 +46,7 @@ namespace Nz::ShaderAst
|
||||
std::shared_ptr<ShaderModuleResolver> moduleResolver;
|
||||
std::unordered_set<std::string> reservedIdentifiers;
|
||||
std::unordered_map<UInt32, ConstantValue> optionValues;
|
||||
bool allowPartialSanitization = false;
|
||||
bool makeVariableNameUnique = false;
|
||||
bool reduceLoopsToWhile = false;
|
||||
bool removeAliases = false;
|
||||
@@ -60,6 +61,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
private:
|
||||
enum class IdentifierCategory;
|
||||
enum class ValidationResult;
|
||||
struct AstError;
|
||||
struct CurrentFunctionData;
|
||||
struct Environment;
|
||||
@@ -110,10 +112,12 @@ namespace Nz::ShaderAst
|
||||
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
|
||||
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
|
||||
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
|
||||
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
|
||||
|
||||
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
|
||||
|
||||
const ExpressionType* GetExpressionType(Expression& expr) const;
|
||||
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
|
||||
|
||||
Expression& MandatoryExpr(const ExpressionPtr& node) const;
|
||||
Statement& MandatoryStatement(const StatementPtr& node) const;
|
||||
|
||||
@@ -122,8 +126,9 @@ namespace Nz::ShaderAst
|
||||
|
||||
ExpressionPtr CacheResult(ExpressionPtr expression);
|
||||
|
||||
ConstantValue ComputeConstantValue(Expression& expr) const;
|
||||
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||
std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
|
||||
template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
|
||||
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
||||
|
||||
void PreregisterIndices(const Module& module);
|
||||
@@ -131,49 +136,49 @@ namespace Nz::ShaderAst
|
||||
|
||||
void RegisterBuiltin();
|
||||
|
||||
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
|
||||
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
|
||||
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
|
||||
void RegisterUnresolved(std::string name);
|
||||
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
|
||||
|
||||
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
|
||||
void ResolveFunctions();
|
||||
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
||||
std::size_t ResolveStruct(const AliasType& aliasType);
|
||||
std::size_t ResolveStruct(const ExpressionType& exprType);
|
||||
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
||||
std::size_t ResolveStruct(const StructType& structType);
|
||||
std::size_t ResolveStruct(const UniformType& uniformType);
|
||||
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
|
||||
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||
std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||
|
||||
void SanitizeIdentifier(std::string& identifier);
|
||||
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
||||
|
||||
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
|
||||
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
|
||||
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
|
||||
|
||||
StatementPtr Unscope(StatementPtr node);
|
||||
|
||||
void Validate(DeclareAliasStatement& node);
|
||||
void Validate(WhileStatement& node);
|
||||
ValidationResult Validate(DeclareAliasStatement& node);
|
||||
ValidationResult Validate(WhileStatement& node);
|
||||
|
||||
void Validate(AccessIndexExpression& node);
|
||||
void Validate(AssignExpression& node);
|
||||
void Validate(BinaryExpression& node);
|
||||
void Validate(CallFunctionExpression& node);
|
||||
void Validate(CastExpression& node);
|
||||
void Validate(DeclareVariableStatement& node);
|
||||
void Validate(IntrinsicExpression& node);
|
||||
void Validate(SwizzleExpression& node);
|
||||
void Validate(UnaryExpression& node);
|
||||
void Validate(VariableValueExpression& node);
|
||||
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
|
||||
ValidationResult Validate(AccessIndexExpression& node);
|
||||
ValidationResult Validate(AssignExpression& node);
|
||||
ValidationResult Validate(BinaryExpression& node);
|
||||
ValidationResult Validate(CallFunctionExpression& node);
|
||||
ValidationResult Validate(CastExpression& node);
|
||||
ValidationResult Validate(DeclareVariableStatement& node);
|
||||
ValidationResult Validate(IntrinsicExpression& node);
|
||||
ValidationResult Validate(SwizzleExpression& node);
|
||||
ValidationResult Validate(UnaryExpression& node);
|
||||
ValidationResult Validate(VariableValueExpression& node);
|
||||
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
|
||||
|
||||
enum class IdentifierCategory
|
||||
{
|
||||
@@ -184,9 +189,16 @@ namespace Nz::ShaderAst
|
||||
Module,
|
||||
Struct,
|
||||
Type,
|
||||
Unresolved,
|
||||
Variable
|
||||
};
|
||||
|
||||
enum class ValidationResult
|
||||
{
|
||||
Validated,
|
||||
Unresolved
|
||||
};
|
||||
|
||||
struct FunctionData
|
||||
{
|
||||
Bitset<> calledByFunctions;
|
||||
|
||||
Reference in New Issue
Block a user