Shader: Add support for partial sanitization
This commit is contained in:
@@ -65,7 +65,7 @@ namespace Nz
|
||||
NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/);
|
||||
|
||||
private:
|
||||
void Validate(ShaderAst::Module& module);
|
||||
ShaderAst::ModulePtr Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options);
|
||||
|
||||
NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated);
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ namespace Nz::ShaderAst
|
||||
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
|
||||
virtual ExpressionPtr Clone(StructTypeExpression& node);
|
||||
virtual ExpressionPtr Clone(SwizzleExpression& node);
|
||||
virtual ExpressionPtr Clone(TypeExpression& node);
|
||||
virtual ExpressionPtr Clone(VariableValueExpression& node);
|
||||
virtual ExpressionPtr Clone(UnaryExpression& node);
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ namespace Nz::ShaderAst
|
||||
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
|
||||
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
|
||||
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
|
||||
inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs);
|
||||
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs);
|
||||
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);
|
||||
|
||||
|
||||
@@ -407,6 +407,14 @@ namespace Nz::ShaderAst
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Compare(const TypeExpression& lhs, const TypeExpression& rhs)
|
||||
{
|
||||
if (!Compare(lhs.typeId, rhs.typeId))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs)
|
||||
{
|
||||
if (!Compare(lhs.variableId, rhs.variableId))
|
||||
|
||||
@@ -36,7 +36,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
struct Options
|
||||
{
|
||||
std::function<const ConstantValue&(std::size_t constantId)> constantQueryCallback;
|
||||
std::function<const ConstantValue*(std::size_t constantId)> constantQueryCallback;
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
@@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(TypeExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(VariableValueExpression)
|
||||
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
|
||||
NAZARA_SHADERAST_STATEMENT(BranchStatement)
|
||||
|
||||
@@ -37,6 +37,7 @@ namespace Nz::ShaderAst
|
||||
void Visit(IntrinsicFunctionExpression& node) override;
|
||||
void Visit(StructTypeExpression& node) override;
|
||||
void Visit(SwizzleExpression& node) override;
|
||||
void Visit(TypeExpression& node) override;
|
||||
void Visit(VariableValueExpression& node) override;
|
||||
void Visit(UnaryExpression& node) override;
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ namespace Nz::ShaderAst
|
||||
void Serialize(IntrinsicFunctionExpression& node);
|
||||
void Serialize(StructTypeExpression& node);
|
||||
void Serialize(SwizzleExpression& node);
|
||||
void Serialize(TypeExpression& node);
|
||||
void Serialize(VariableValueExpression& node);
|
||||
void Serialize(UnaryExpression& node);
|
||||
void SerializeExpressionCommon(Expression& expr);
|
||||
|
||||
@@ -48,6 +48,7 @@ namespace Nz::ShaderAst
|
||||
void Visit(IntrinsicFunctionExpression& node) override;
|
||||
void Visit(StructTypeExpression& node) override;
|
||||
void Visit(SwizzleExpression& node) override;
|
||||
void Visit(TypeExpression& node) override;
|
||||
void Visit(VariableValueExpression& node) override;
|
||||
void Visit(UnaryExpression& node) override;
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
|
||||
|
||||
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant);
|
||||
NAZARA_SHADER_API ExpressionType GetConstantType(const ConstantValue& constant);
|
||||
}
|
||||
|
||||
#endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP
|
||||
|
||||
@@ -215,6 +215,14 @@ namespace Nz::ShaderAst
|
||||
ExpressionPtr expression;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API TypeExpression : Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
void Visit(AstExpressionVisitor& visitor) override;
|
||||
|
||||
std::size_t typeId;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API VariableValueExpression : Expression
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
@@ -462,8 +470,8 @@ namespace Nz::ShaderAst
|
||||
|
||||
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
||||
|
||||
inline const ExpressionType& GetExpressionType(Expression& expr);
|
||||
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
|
||||
inline const ExpressionType* GetExpressionType(Expression& expr);
|
||||
inline ExpressionType* GetExpressionTypeMut(Expression& expr);
|
||||
inline bool IsExpression(NodeType nodeType);
|
||||
inline bool IsStatement(NodeType nodeType);
|
||||
|
||||
|
||||
@@ -7,16 +7,14 @@
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
inline const ExpressionType& GetExpressionType(Expression& expr)
|
||||
inline const ExpressionType* GetExpressionType(Expression& expr)
|
||||
{
|
||||
assert(expr.cachedExpressionType);
|
||||
return expr.cachedExpressionType.value();
|
||||
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
|
||||
}
|
||||
|
||||
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
|
||||
inline ExpressionType* GetExpressionTypeMut(Expression& expr)
|
||||
{
|
||||
assert(expr.cachedExpressionType);
|
||||
return expr.cachedExpressionType.value();
|
||||
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
|
||||
}
|
||||
|
||||
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
|
||||
|
||||
@@ -46,6 +46,7 @@ namespace Nz::ShaderAst
|
||||
std::shared_ptr<ShaderModuleResolver> moduleResolver;
|
||||
std::unordered_set<std::string> reservedIdentifiers;
|
||||
std::unordered_map<UInt32, ConstantValue> optionValues;
|
||||
bool allowPartialSanitization = false;
|
||||
bool makeVariableNameUnique = false;
|
||||
bool reduceLoopsToWhile = false;
|
||||
bool removeAliases = false;
|
||||
@@ -60,6 +61,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
private:
|
||||
enum class IdentifierCategory;
|
||||
enum class ValidationResult;
|
||||
struct AstError;
|
||||
struct CurrentFunctionData;
|
||||
struct Environment;
|
||||
@@ -110,10 +112,12 @@ namespace Nz::ShaderAst
|
||||
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
|
||||
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
|
||||
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
|
||||
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
|
||||
|
||||
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
|
||||
|
||||
const ExpressionType* GetExpressionType(Expression& expr) const;
|
||||
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
|
||||
|
||||
Expression& MandatoryExpr(const ExpressionPtr& node) const;
|
||||
Statement& MandatoryStatement(const StatementPtr& node) const;
|
||||
|
||||
@@ -122,8 +126,9 @@ namespace Nz::ShaderAst
|
||||
|
||||
ExpressionPtr CacheResult(ExpressionPtr expression);
|
||||
|
||||
ConstantValue ComputeConstantValue(Expression& expr) const;
|
||||
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||
std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
|
||||
template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
|
||||
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
||||
|
||||
void PreregisterIndices(const Module& module);
|
||||
@@ -131,49 +136,49 @@ namespace Nz::ShaderAst
|
||||
|
||||
void RegisterBuiltin();
|
||||
|
||||
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
|
||||
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
|
||||
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
|
||||
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
|
||||
void RegisterUnresolved(std::string name);
|
||||
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
|
||||
|
||||
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
|
||||
void ResolveFunctions();
|
||||
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
||||
std::size_t ResolveStruct(const AliasType& aliasType);
|
||||
std::size_t ResolveStruct(const ExpressionType& exprType);
|
||||
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
||||
std::size_t ResolveStruct(const StructType& structType);
|
||||
std::size_t ResolveStruct(const UniformType& uniformType);
|
||||
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
|
||||
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||
std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||
|
||||
void SanitizeIdentifier(std::string& identifier);
|
||||
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
||||
|
||||
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
|
||||
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
|
||||
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
|
||||
|
||||
StatementPtr Unscope(StatementPtr node);
|
||||
|
||||
void Validate(DeclareAliasStatement& node);
|
||||
void Validate(WhileStatement& node);
|
||||
ValidationResult Validate(DeclareAliasStatement& node);
|
||||
ValidationResult Validate(WhileStatement& node);
|
||||
|
||||
void Validate(AccessIndexExpression& node);
|
||||
void Validate(AssignExpression& node);
|
||||
void Validate(BinaryExpression& node);
|
||||
void Validate(CallFunctionExpression& node);
|
||||
void Validate(CastExpression& node);
|
||||
void Validate(DeclareVariableStatement& node);
|
||||
void Validate(IntrinsicExpression& node);
|
||||
void Validate(SwizzleExpression& node);
|
||||
void Validate(UnaryExpression& node);
|
||||
void Validate(VariableValueExpression& node);
|
||||
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
|
||||
ValidationResult Validate(AccessIndexExpression& node);
|
||||
ValidationResult Validate(AssignExpression& node);
|
||||
ValidationResult Validate(BinaryExpression& node);
|
||||
ValidationResult Validate(CallFunctionExpression& node);
|
||||
ValidationResult Validate(CastExpression& node);
|
||||
ValidationResult Validate(DeclareVariableStatement& node);
|
||||
ValidationResult Validate(IntrinsicExpression& node);
|
||||
ValidationResult Validate(SwizzleExpression& node);
|
||||
ValidationResult Validate(UnaryExpression& node);
|
||||
ValidationResult Validate(VariableValueExpression& node);
|
||||
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
|
||||
|
||||
enum class IdentifierCategory
|
||||
{
|
||||
@@ -184,9 +189,16 @@ namespace Nz::ShaderAst
|
||||
Module,
|
||||
Struct,
|
||||
Type,
|
||||
Unresolved,
|
||||
Variable
|
||||
};
|
||||
|
||||
enum class ValidationResult
|
||||
{
|
||||
Validated,
|
||||
Unresolved
|
||||
};
|
||||
|
||||
struct FunctionData
|
||||
{
|
||||
Bitset<> calledByFunctions;
|
||||
|
||||
@@ -173,7 +173,7 @@ namespace Nz::ShaderBuilder
|
||||
{
|
||||
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
|
||||
constantNode->value = std::move(value);
|
||||
constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value);
|
||||
constantNode->cachedExpressionType = ShaderAst::GetConstantType(constantNode->value);
|
||||
|
||||
return constantNode;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user