Shader: Add support for partial sanitization

This commit is contained in:
SirLynix
2022-03-25 12:54:51 +01:00
parent a54f70fd24
commit 8146ec251a
31 changed files with 1105 additions and 521 deletions

View File

@@ -65,7 +65,7 @@ namespace Nz
NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/);
private:
void Validate(ShaderAst::Module& module);
ShaderAst::ModulePtr Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options);
NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated);

View File

@@ -57,6 +57,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
virtual ExpressionPtr Clone(StructTypeExpression& node);
virtual ExpressionPtr Clone(SwizzleExpression& node);
virtual ExpressionPtr Clone(TypeExpression& node);
virtual ExpressionPtr Clone(VariableValueExpression& node);
virtual ExpressionPtr Clone(UnaryExpression& node);

View File

@@ -49,6 +49,7 @@ namespace Nz::ShaderAst
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs);
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs);
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);

View File

@@ -407,6 +407,14 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const TypeExpression& lhs, const TypeExpression& rhs)
{
if (!Compare(lhs.typeId, rhs.typeId))
return false;
return true;
}
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs)
{
if (!Compare(lhs.variableId, rhs.variableId))

View File

@@ -36,7 +36,7 @@ namespace Nz::ShaderAst
struct Options
{
std::function<const ConstantValue&(std::size_t constantId)> constantQueryCallback;
std::function<const ConstantValue*(std::size_t constantId)> constantQueryCallback;
};
protected:

View File

@@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_EXPRESSION(TypeExpression)
NAZARA_SHADERAST_EXPRESSION(VariableValueExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement)

View File

@@ -37,6 +37,7 @@ namespace Nz::ShaderAst
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(TypeExpression& node) override;
void Visit(VariableValueExpression& node) override;
void Visit(UnaryExpression& node) override;

View File

@@ -40,6 +40,7 @@ namespace Nz::ShaderAst
void Serialize(IntrinsicFunctionExpression& node);
void Serialize(StructTypeExpression& node);
void Serialize(SwizzleExpression& node);
void Serialize(TypeExpression& node);
void Serialize(VariableValueExpression& node);
void Serialize(UnaryExpression& node);
void SerializeExpressionCommon(Expression& expr);

View File

@@ -48,6 +48,7 @@ namespace Nz::ShaderAst
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(TypeExpression& node) override;
void Visit(VariableValueExpression& node) override;
void Visit(UnaryExpression& node) override;

View File

@@ -37,7 +37,7 @@ namespace Nz::ShaderAst
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant);
NAZARA_SHADER_API ExpressionType GetConstantType(const ConstantValue& constant);
}
#endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP

View File

@@ -215,6 +215,14 @@ namespace Nz::ShaderAst
ExpressionPtr expression;
};
struct NAZARA_SHADER_API TypeExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t typeId;
};
struct NAZARA_SHADER_API VariableValueExpression : Expression
{
NodeType GetType() const override;
@@ -462,8 +470,8 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ExpressionType& GetExpressionType(Expression& expr);
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
inline const ExpressionType* GetExpressionType(Expression& expr);
inline ExpressionType* GetExpressionTypeMut(Expression& expr);
inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType);

View File

@@ -7,16 +7,14 @@
namespace Nz::ShaderAst
{
inline const ExpressionType& GetExpressionType(Expression& expr)
inline const ExpressionType* GetExpressionType(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
}
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
inline ExpressionType* GetExpressionTypeMut(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
}
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)

View File

@@ -46,6 +46,7 @@ namespace Nz::ShaderAst
std::shared_ptr<ShaderModuleResolver> moduleResolver;
std::unordered_set<std::string> reservedIdentifiers;
std::unordered_map<UInt32, ConstantValue> optionValues;
bool allowPartialSanitization = false;
bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false;
bool removeAliases = false;
@@ -60,6 +61,7 @@ namespace Nz::ShaderAst
private:
enum class IdentifierCategory;
enum class ValidationResult;
struct AstError;
struct CurrentFunctionData;
struct Environment;
@@ -110,10 +112,12 @@ namespace Nz::ShaderAst
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
const ExpressionType* GetExpressionType(Expression& expr) const;
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const;
@@ -122,8 +126,9 @@ namespace Nz::ShaderAst
ExpressionPtr CacheResult(ExpressionPtr expression);
ConstantValue ComputeConstantValue(Expression& expr) const;
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
void PreregisterIndices(const Module& module);
@@ -131,49 +136,49 @@ namespace Nz::ShaderAst
void RegisterBuiltin();
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
void RegisterUnresolved(std::string name);
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
StatementPtr Unscope(StatementPtr node);
void Validate(DeclareAliasStatement& node);
void Validate(WhileStatement& node);
ValidationResult Validate(DeclareAliasStatement& node);
ValidationResult Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node);
void Validate(BinaryExpression& node);
void Validate(CallFunctionExpression& node);
void Validate(CastExpression& node);
void Validate(DeclareVariableStatement& node);
void Validate(IntrinsicExpression& node);
void Validate(SwizzleExpression& node);
void Validate(UnaryExpression& node);
void Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
ValidationResult Validate(AccessIndexExpression& node);
ValidationResult Validate(AssignExpression& node);
ValidationResult Validate(BinaryExpression& node);
ValidationResult Validate(CallFunctionExpression& node);
ValidationResult Validate(CastExpression& node);
ValidationResult Validate(DeclareVariableStatement& node);
ValidationResult Validate(IntrinsicExpression& node);
ValidationResult Validate(SwizzleExpression& node);
ValidationResult Validate(UnaryExpression& node);
ValidationResult Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
enum class IdentifierCategory
{
@@ -184,9 +189,16 @@ namespace Nz::ShaderAst
Module,
Struct,
Type,
Unresolved,
Variable
};
enum class ValidationResult
{
Validated,
Unresolved
};
struct FunctionData
{
Bitset<> calledByFunctions;

View File

@@ -173,7 +173,7 @@ namespace Nz::ShaderBuilder
{
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
constantNode->value = std::move(value);
constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value);
constantNode->cachedExpressionType = ShaderAst::GetConstantType(constantNode->value);
return constantNode;
}