Shader: Add support for partial sanitization

This commit is contained in:
SirLynix
2022-03-25 12:54:51 +01:00
parent a54f70fd24
commit 8146ec251a
31 changed files with 1105 additions and 521 deletions

View File

@@ -46,6 +46,7 @@ namespace Nz::ShaderAst
std::shared_ptr<ShaderModuleResolver> moduleResolver;
std::unordered_set<std::string> reservedIdentifiers;
std::unordered_map<UInt32, ConstantValue> optionValues;
bool allowPartialSanitization = false;
bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false;
bool removeAliases = false;
@@ -60,6 +61,7 @@ namespace Nz::ShaderAst
private:
enum class IdentifierCategory;
enum class ValidationResult;
struct AstError;
struct CurrentFunctionData;
struct Environment;
@@ -110,10 +112,12 @@ namespace Nz::ShaderAst
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
const ExpressionType* GetExpressionType(Expression& expr) const;
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const;
@@ -122,8 +126,9 @@ namespace Nz::ShaderAst
ExpressionPtr CacheResult(ExpressionPtr expression);
ConstantValue ComputeConstantValue(Expression& expr) const;
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
void PreregisterIndices(const Module& module);
@@ -131,49 +136,49 @@ namespace Nz::ShaderAst
void RegisterBuiltin();
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
void RegisterUnresolved(std::string name);
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
StatementPtr Unscope(StatementPtr node);
void Validate(DeclareAliasStatement& node);
void Validate(WhileStatement& node);
ValidationResult Validate(DeclareAliasStatement& node);
ValidationResult Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node);
void Validate(BinaryExpression& node);
void Validate(CallFunctionExpression& node);
void Validate(CastExpression& node);
void Validate(DeclareVariableStatement& node);
void Validate(IntrinsicExpression& node);
void Validate(SwizzleExpression& node);
void Validate(UnaryExpression& node);
void Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
ValidationResult Validate(AccessIndexExpression& node);
ValidationResult Validate(AssignExpression& node);
ValidationResult Validate(BinaryExpression& node);
ValidationResult Validate(CallFunctionExpression& node);
ValidationResult Validate(CastExpression& node);
ValidationResult Validate(DeclareVariableStatement& node);
ValidationResult Validate(IntrinsicExpression& node);
ValidationResult Validate(SwizzleExpression& node);
ValidationResult Validate(UnaryExpression& node);
ValidationResult Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
enum class IdentifierCategory
{
@@ -184,9 +189,16 @@ namespace Nz::ShaderAst
Module,
Struct,
Type,
Unresolved,
Variable
};
enum class ValidationResult
{
Validated,
Unresolved
};
struct FunctionData
{
Bitset<> calledByFunctions;