Shader: Add support for partial sanitization
This commit is contained in:
parent
a54f70fd24
commit
8146ec251a
|
|
@ -65,7 +65,7 @@ namespace Nz
|
||||||
NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/);
|
NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Validate(ShaderAst::Module& module);
|
ShaderAst::ModulePtr Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options);
|
||||||
|
|
||||||
NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated);
|
NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,7 @@ namespace Nz::ShaderAst
|
||||||
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
|
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
|
||||||
virtual ExpressionPtr Clone(StructTypeExpression& node);
|
virtual ExpressionPtr Clone(StructTypeExpression& node);
|
||||||
virtual ExpressionPtr Clone(SwizzleExpression& node);
|
virtual ExpressionPtr Clone(SwizzleExpression& node);
|
||||||
|
virtual ExpressionPtr Clone(TypeExpression& node);
|
||||||
virtual ExpressionPtr Clone(VariableValueExpression& node);
|
virtual ExpressionPtr Clone(VariableValueExpression& node);
|
||||||
virtual ExpressionPtr Clone(UnaryExpression& node);
|
virtual ExpressionPtr Clone(UnaryExpression& node);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ namespace Nz::ShaderAst
|
||||||
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
|
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
|
||||||
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
|
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
|
||||||
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
|
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
|
||||||
|
inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs);
|
||||||
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs);
|
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs);
|
||||||
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);
|
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -407,6 +407,14 @@ namespace Nz::ShaderAst
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Compare(const TypeExpression& lhs, const TypeExpression& rhs)
|
||||||
|
{
|
||||||
|
if (!Compare(lhs.typeId, rhs.typeId))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs)
|
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs)
|
||||||
{
|
{
|
||||||
if (!Compare(lhs.variableId, rhs.variableId))
|
if (!Compare(lhs.variableId, rhs.variableId))
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
struct Options
|
struct Options
|
||||||
{
|
{
|
||||||
std::function<const ConstantValue&(std::size_t constantId)> constantQueryCallback;
|
std::function<const ConstantValue*(std::size_t constantId)> constantQueryCallback;
|
||||||
};
|
};
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
|
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
|
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
|
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
|
||||||
|
NAZARA_SHADERAST_EXPRESSION(TypeExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(VariableValueExpression)
|
NAZARA_SHADERAST_EXPRESSION(VariableValueExpression)
|
||||||
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
|
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
|
||||||
NAZARA_SHADERAST_STATEMENT(BranchStatement)
|
NAZARA_SHADERAST_STATEMENT(BranchStatement)
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ namespace Nz::ShaderAst
|
||||||
void Visit(IntrinsicFunctionExpression& node) override;
|
void Visit(IntrinsicFunctionExpression& node) override;
|
||||||
void Visit(StructTypeExpression& node) override;
|
void Visit(StructTypeExpression& node) override;
|
||||||
void Visit(SwizzleExpression& node) override;
|
void Visit(SwizzleExpression& node) override;
|
||||||
|
void Visit(TypeExpression& node) override;
|
||||||
void Visit(VariableValueExpression& node) override;
|
void Visit(VariableValueExpression& node) override;
|
||||||
void Visit(UnaryExpression& node) override;
|
void Visit(UnaryExpression& node) override;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ namespace Nz::ShaderAst
|
||||||
void Serialize(IntrinsicFunctionExpression& node);
|
void Serialize(IntrinsicFunctionExpression& node);
|
||||||
void Serialize(StructTypeExpression& node);
|
void Serialize(StructTypeExpression& node);
|
||||||
void Serialize(SwizzleExpression& node);
|
void Serialize(SwizzleExpression& node);
|
||||||
|
void Serialize(TypeExpression& node);
|
||||||
void Serialize(VariableValueExpression& node);
|
void Serialize(VariableValueExpression& node);
|
||||||
void Serialize(UnaryExpression& node);
|
void Serialize(UnaryExpression& node);
|
||||||
void SerializeExpressionCommon(Expression& expr);
|
void SerializeExpressionCommon(Expression& expr);
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,7 @@ namespace Nz::ShaderAst
|
||||||
void Visit(IntrinsicFunctionExpression& node) override;
|
void Visit(IntrinsicFunctionExpression& node) override;
|
||||||
void Visit(StructTypeExpression& node) override;
|
void Visit(StructTypeExpression& node) override;
|
||||||
void Visit(SwizzleExpression& node) override;
|
void Visit(SwizzleExpression& node) override;
|
||||||
|
void Visit(TypeExpression& node) override;
|
||||||
void Visit(VariableValueExpression& node) override;
|
void Visit(VariableValueExpression& node) override;
|
||||||
void Visit(UnaryExpression& node) override;
|
void Visit(UnaryExpression& node) override;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
|
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
|
||||||
|
|
||||||
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant);
|
NAZARA_SHADER_API ExpressionType GetConstantType(const ConstantValue& constant);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP
|
#endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,14 @@ namespace Nz::ShaderAst
|
||||||
ExpressionPtr expression;
|
ExpressionPtr expression;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct NAZARA_SHADER_API TypeExpression : Expression
|
||||||
|
{
|
||||||
|
NodeType GetType() const override;
|
||||||
|
void Visit(AstExpressionVisitor& visitor) override;
|
||||||
|
|
||||||
|
std::size_t typeId;
|
||||||
|
};
|
||||||
|
|
||||||
struct NAZARA_SHADER_API VariableValueExpression : Expression
|
struct NAZARA_SHADER_API VariableValueExpression : Expression
|
||||||
{
|
{
|
||||||
NodeType GetType() const override;
|
NodeType GetType() const override;
|
||||||
|
|
@ -462,8 +470,8 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
#include <Nazara/Shader/Ast/AstNodeList.hpp>
|
||||||
|
|
||||||
inline const ExpressionType& GetExpressionType(Expression& expr);
|
inline const ExpressionType* GetExpressionType(Expression& expr);
|
||||||
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
|
inline ExpressionType* GetExpressionTypeMut(Expression& expr);
|
||||||
inline bool IsExpression(NodeType nodeType);
|
inline bool IsExpression(NodeType nodeType);
|
||||||
inline bool IsStatement(NodeType nodeType);
|
inline bool IsStatement(NodeType nodeType);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,16 +7,14 @@
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
inline const ExpressionType& GetExpressionType(Expression& expr)
|
inline const ExpressionType* GetExpressionType(Expression& expr)
|
||||||
{
|
{
|
||||||
assert(expr.cachedExpressionType);
|
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
|
||||||
return expr.cachedExpressionType.value();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
|
inline ExpressionType* GetExpressionTypeMut(Expression& expr)
|
||||||
{
|
{
|
||||||
assert(expr.cachedExpressionType);
|
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
|
||||||
return expr.cachedExpressionType.value();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
|
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ namespace Nz::ShaderAst
|
||||||
std::shared_ptr<ShaderModuleResolver> moduleResolver;
|
std::shared_ptr<ShaderModuleResolver> moduleResolver;
|
||||||
std::unordered_set<std::string> reservedIdentifiers;
|
std::unordered_set<std::string> reservedIdentifiers;
|
||||||
std::unordered_map<UInt32, ConstantValue> optionValues;
|
std::unordered_map<UInt32, ConstantValue> optionValues;
|
||||||
|
bool allowPartialSanitization = false;
|
||||||
bool makeVariableNameUnique = false;
|
bool makeVariableNameUnique = false;
|
||||||
bool reduceLoopsToWhile = false;
|
bool reduceLoopsToWhile = false;
|
||||||
bool removeAliases = false;
|
bool removeAliases = false;
|
||||||
|
|
@ -60,6 +61,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
private:
|
private:
|
||||||
enum class IdentifierCategory;
|
enum class IdentifierCategory;
|
||||||
|
enum class ValidationResult;
|
||||||
struct AstError;
|
struct AstError;
|
||||||
struct CurrentFunctionData;
|
struct CurrentFunctionData;
|
||||||
struct Environment;
|
struct Environment;
|
||||||
|
|
@ -110,10 +112,12 @@ namespace Nz::ShaderAst
|
||||||
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
|
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
|
||||||
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
|
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
|
||||||
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
|
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
|
||||||
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
|
|
||||||
|
|
||||||
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
|
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
|
||||||
|
|
||||||
|
const ExpressionType* GetExpressionType(Expression& expr) const;
|
||||||
|
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
|
||||||
|
|
||||||
Expression& MandatoryExpr(const ExpressionPtr& node) const;
|
Expression& MandatoryExpr(const ExpressionPtr& node) const;
|
||||||
Statement& MandatoryStatement(const StatementPtr& node) const;
|
Statement& MandatoryStatement(const StatementPtr& node) const;
|
||||||
|
|
||||||
|
|
@ -122,8 +126,9 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
ExpressionPtr CacheResult(ExpressionPtr expression);
|
ExpressionPtr CacheResult(ExpressionPtr expression);
|
||||||
|
|
||||||
ConstantValue ComputeConstantValue(Expression& expr) const;
|
std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
|
||||||
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
|
template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
|
||||||
|
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
|
||||||
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
|
||||||
|
|
||||||
void PreregisterIndices(const Module& module);
|
void PreregisterIndices(const Module& module);
|
||||||
|
|
@ -131,49 +136,49 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void RegisterBuiltin();
|
void RegisterBuiltin();
|
||||||
|
|
||||||
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
|
std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
|
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
|
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
|
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
|
||||||
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
|
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
|
||||||
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
|
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
|
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
|
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
|
||||||
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
|
void RegisterUnresolved(std::string name);
|
||||||
|
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
|
||||||
|
|
||||||
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
|
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
|
||||||
void ResolveFunctions();
|
void ResolveFunctions();
|
||||||
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
|
|
||||||
std::size_t ResolveStruct(const AliasType& aliasType);
|
std::size_t ResolveStruct(const AliasType& aliasType);
|
||||||
std::size_t ResolveStruct(const ExpressionType& exprType);
|
std::size_t ResolveStruct(const ExpressionType& exprType);
|
||||||
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
std::size_t ResolveStruct(const IdentifierType& identifierType);
|
||||||
std::size_t ResolveStruct(const StructType& structType);
|
std::size_t ResolveStruct(const StructType& structType);
|
||||||
std::size_t ResolveStruct(const UniformType& uniformType);
|
std::size_t ResolveStruct(const UniformType& uniformType);
|
||||||
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
|
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
|
||||||
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
|
||||||
|
|
||||||
void SanitizeIdentifier(std::string& identifier);
|
void SanitizeIdentifier(std::string& identifier);
|
||||||
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
|
||||||
|
|
||||||
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
|
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
|
||||||
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
|
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
|
||||||
|
|
||||||
StatementPtr Unscope(StatementPtr node);
|
StatementPtr Unscope(StatementPtr node);
|
||||||
|
|
||||||
void Validate(DeclareAliasStatement& node);
|
ValidationResult Validate(DeclareAliasStatement& node);
|
||||||
void Validate(WhileStatement& node);
|
ValidationResult Validate(WhileStatement& node);
|
||||||
|
|
||||||
void Validate(AccessIndexExpression& node);
|
ValidationResult Validate(AccessIndexExpression& node);
|
||||||
void Validate(AssignExpression& node);
|
ValidationResult Validate(AssignExpression& node);
|
||||||
void Validate(BinaryExpression& node);
|
ValidationResult Validate(BinaryExpression& node);
|
||||||
void Validate(CallFunctionExpression& node);
|
ValidationResult Validate(CallFunctionExpression& node);
|
||||||
void Validate(CastExpression& node);
|
ValidationResult Validate(CastExpression& node);
|
||||||
void Validate(DeclareVariableStatement& node);
|
ValidationResult Validate(DeclareVariableStatement& node);
|
||||||
void Validate(IntrinsicExpression& node);
|
ValidationResult Validate(IntrinsicExpression& node);
|
||||||
void Validate(SwizzleExpression& node);
|
ValidationResult Validate(SwizzleExpression& node);
|
||||||
void Validate(UnaryExpression& node);
|
ValidationResult Validate(UnaryExpression& node);
|
||||||
void Validate(VariableValueExpression& node);
|
ValidationResult Validate(VariableValueExpression& node);
|
||||||
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
|
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
|
||||||
|
|
||||||
enum class IdentifierCategory
|
enum class IdentifierCategory
|
||||||
{
|
{
|
||||||
|
|
@ -184,9 +189,16 @@ namespace Nz::ShaderAst
|
||||||
Module,
|
Module,
|
||||||
Struct,
|
Struct,
|
||||||
Type,
|
Type,
|
||||||
|
Unresolved,
|
||||||
Variable
|
Variable
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum class ValidationResult
|
||||||
|
{
|
||||||
|
Validated,
|
||||||
|
Unresolved
|
||||||
|
};
|
||||||
|
|
||||||
struct FunctionData
|
struct FunctionData
|
||||||
{
|
{
|
||||||
Bitset<> calledByFunctions;
|
Bitset<> calledByFunctions;
|
||||||
|
|
|
||||||
|
|
@ -173,7 +173,7 @@ namespace Nz::ShaderBuilder
|
||||||
{
|
{
|
||||||
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
|
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
|
||||||
constantNode->value = std::move(value);
|
constantNode->value = std::move(value);
|
||||||
constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value);
|
constantNode->cachedExpressionType = ShaderAst::GetConstantType(constantNode->value);
|
||||||
|
|
||||||
return constantNode;
|
return constantNode;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -249,7 +249,7 @@ struct VertIn
|
||||||
}
|
}
|
||||||
|
|
||||||
[entry(vert), cond(Billboard)]
|
[entry(vert), cond(Billboard)]
|
||||||
fn billboardMain(input: VertIn) -> VertOut
|
fn billboardMain(input: VertIn) -> VertToFrag
|
||||||
{
|
{
|
||||||
let size = input.billboardSizeRot.xy;
|
let size = input.billboardSizeRot.xy;
|
||||||
let sinCos = input.billboardSizeRot.zw;
|
let sinCos = input.billboardSizeRot.zw;
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ namespace Nz
|
||||||
m_shaderModule = moduleResolver.Resolve(moduleName);
|
m_shaderModule = moduleResolver.Resolve(moduleName);
|
||||||
NazaraAssert(m_shaderModule, "invalid shader module");
|
NazaraAssert(m_shaderModule, "invalid shader module");
|
||||||
|
|
||||||
Validate(*m_shaderModule);
|
m_shaderModule = Validate(*m_shaderModule, &m_optionIndexByName);
|
||||||
|
|
||||||
m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName)
|
m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName)
|
||||||
{
|
{
|
||||||
|
|
@ -41,8 +41,7 @@ namespace Nz
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// FIXME: Validate is destructive, in case of failure it can invalidate the shader
|
m_shaderModule = Validate(*newShaderModule, &m_optionIndexByName);
|
||||||
Validate(*newShaderModule);
|
|
||||||
}
|
}
|
||||||
catch (const std::exception& e)
|
catch (const std::exception& e)
|
||||||
{
|
{
|
||||||
|
|
@ -50,8 +49,6 @@ namespace Nz
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
m_shaderModule = std::move(newShaderModule);
|
|
||||||
|
|
||||||
// Clear cache
|
// Clear cache
|
||||||
m_combinations.clear();
|
m_combinations.clear();
|
||||||
|
|
||||||
|
|
@ -65,7 +62,7 @@ namespace Nz
|
||||||
{
|
{
|
||||||
NazaraAssert(m_shaderModule, "invalid shader module");
|
NazaraAssert(m_shaderModule, "invalid shader module");
|
||||||
|
|
||||||
Validate(*m_shaderModule);
|
Validate(*m_shaderModule, &m_optionIndexByName);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::shared_ptr<ShaderModule>& UberShader::Get(const Config& config)
|
const std::shared_ptr<ShaderModule>& UberShader::Get(const Config& config)
|
||||||
|
|
@ -85,13 +82,17 @@ namespace Nz
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void UberShader::Validate(ShaderAst::Module& module)
|
ShaderAst::ModulePtr UberShader::Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options)
|
||||||
{
|
{
|
||||||
NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage");
|
NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage");
|
||||||
|
assert(options);
|
||||||
|
|
||||||
//TODO: Try to partially sanitize shader?
|
// Try to partially sanitize shader
|
||||||
|
|
||||||
std::size_t optionCount = 0;
|
ShaderAst::SanitizeVisitor::Options sanitizeOptions;
|
||||||
|
sanitizeOptions.allowPartialSanitization = true;
|
||||||
|
|
||||||
|
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module, sanitizeOptions);
|
||||||
|
|
||||||
ShaderStageTypeFlags supportedStageType;
|
ShaderStageTypeFlags supportedStageType;
|
||||||
|
|
||||||
|
|
@ -101,21 +102,24 @@ namespace Nz
|
||||||
supportedStageType |= stageType;
|
supportedStageType |= stageType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::unordered_map<std::string, Option> optionByName;
|
||||||
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
|
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
|
||||||
{
|
{
|
||||||
//TODO: Check optionType
|
//TODO: Check optionType
|
||||||
|
|
||||||
m_optionIndexByName[option.optName] = Option{
|
optionByName[option.optName] = Option{
|
||||||
CRC32(option.optName)
|
CRC32(option.optName)
|
||||||
};
|
};
|
||||||
|
|
||||||
optionCount++;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
ShaderAst::AstReflect reflect;
|
ShaderAst::AstReflect reflect;
|
||||||
reflect.Reflect(*module.rootNode, callbacks);
|
reflect.Reflect(*sanitizedModule->rootNode, callbacks);
|
||||||
|
|
||||||
if ((m_shaderStages & supportedStageType) != m_shaderStages)
|
if ((m_shaderStages & supportedStageType) != m_shaderStages)
|
||||||
throw std::runtime_error("shader doesn't support all required shader stages");
|
throw std::runtime_error("shader doesn't support all required shader stages");
|
||||||
|
|
||||||
|
*options = std::move(optionByName);
|
||||||
|
|
||||||
|
return sanitizedModule;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -481,6 +481,16 @@ namespace Nz::ShaderAst
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ExpressionPtr AstCloner::Clone(TypeExpression& node)
|
||||||
|
{
|
||||||
|
auto clone = std::make_unique<TypeExpression>();
|
||||||
|
clone->typeId = node.typeId;
|
||||||
|
|
||||||
|
clone->cachedExpressionType = node.cachedExpressionType;
|
||||||
|
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
ExpressionPtr AstCloner::Clone(VariableValueExpression& node)
|
ExpressionPtr AstCloner::Clone(VariableValueExpression& node)
|
||||||
{
|
{
|
||||||
auto clone = std::make_unique<VariableValueExpression>();
|
auto clone = std::make_unique<VariableValueExpression>();
|
||||||
|
|
|
||||||
|
|
@ -862,7 +862,7 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]);
|
const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]);
|
||||||
|
|
||||||
if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value))
|
if (!constantValues.empty() && GetConstantType(constantValues.front()) != GetConstantType(constantExpr.value))
|
||||||
{
|
{
|
||||||
// Unhandled case, all cast parameters are expected to be of the same type
|
// Unhandled case, all cast parameters are expected to be of the same type
|
||||||
constantValues.clear();
|
constantValues.clear();
|
||||||
|
|
@ -940,16 +940,24 @@ namespace Nz::ShaderAst
|
||||||
std::vector<BranchStatement::ConditionalStatement> statements;
|
std::vector<BranchStatement::ConditionalStatement> statements;
|
||||||
StatementPtr elseStatement;
|
StatementPtr elseStatement;
|
||||||
|
|
||||||
|
bool continuePropagation = true;
|
||||||
for (auto& condStatement : node.condStatements)
|
for (auto& condStatement : node.condStatements)
|
||||||
{
|
{
|
||||||
auto cond = CloneExpression(condStatement.condition);
|
auto cond = CloneExpression(condStatement.condition);
|
||||||
|
|
||||||
if (cond->GetType() == NodeType::ConstantValueExpression)
|
if (continuePropagation && cond->GetType() == NodeType::ConstantValueExpression)
|
||||||
{
|
{
|
||||||
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
||||||
|
|
||||||
const ExpressionType& constantType = GetExpressionType(constant);
|
const ExpressionType* constantType = GetExpressionType(constant);
|
||||||
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
|
if (!constantType)
|
||||||
|
{
|
||||||
|
// unresolved type, can't continue propagating this branch
|
||||||
|
continuePropagation = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!IsPrimitiveType(*constantType) || std::get<PrimitiveType>(*constantType) != PrimitiveType::Boolean)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
bool cValue = std::get<bool>(constant.value);
|
bool cValue = std::get<bool>(constant.value);
|
||||||
|
|
@ -1017,8 +1025,12 @@ namespace Nz::ShaderAst
|
||||||
if (!m_options.constantQueryCallback)
|
if (!m_options.constantQueryCallback)
|
||||||
return AstCloner::Clone(node);
|
return AstCloner::Clone(node);
|
||||||
|
|
||||||
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
|
const ConstantValue* constantValue = m_options.constantQueryCallback(node.constantId);
|
||||||
constant->cachedExpressionType = GetExpressionType(constant->value);
|
if (!constantValue)
|
||||||
|
return AstCloner::Clone(node);
|
||||||
|
|
||||||
|
auto constant = ShaderBuilder::Constant(*constantValue);
|
||||||
|
constant->cachedExpressionType = GetConstantType(constant->value);
|
||||||
|
|
||||||
return constant;
|
return constant;
|
||||||
}
|
}
|
||||||
|
|
@ -1155,7 +1167,7 @@ namespace Nz::ShaderAst
|
||||||
}, lhs.value);
|
}, lhs.value);
|
||||||
|
|
||||||
if (optimized)
|
if (optimized)
|
||||||
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
optimized->cachedExpressionType = GetConstantType(optimized->value);
|
||||||
|
|
||||||
return optimized;
|
return optimized;
|
||||||
}
|
}
|
||||||
|
|
@ -1221,7 +1233,7 @@ namespace Nz::ShaderAst
|
||||||
}, operand.value);
|
}, operand.value);
|
||||||
|
|
||||||
if (optimized)
|
if (optimized)
|
||||||
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
optimized->cachedExpressionType = GetConstantType(optimized->value);
|
||||||
|
|
||||||
return optimized;
|
return optimized;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -109,6 +109,11 @@ namespace Nz::ShaderAst
|
||||||
node.expression->Visit(*this);
|
node.expression->Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AstRecursiveVisitor::Visit(TypeExpression& node)
|
||||||
|
{
|
||||||
|
/* Nothing to do */
|
||||||
|
}
|
||||||
|
|
||||||
void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/)
|
void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/)
|
||||||
{
|
{
|
||||||
/* Nothing to do */
|
/* Nothing to do */
|
||||||
|
|
|
||||||
|
|
@ -174,6 +174,11 @@ namespace Nz::ShaderAst
|
||||||
SizeT(node.structTypeId);
|
SizeT(node.structTypeId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AstSerializerBase::Serialize(TypeExpression& node)
|
||||||
|
{
|
||||||
|
SizeT(node.typeId);
|
||||||
|
}
|
||||||
|
|
||||||
void AstSerializerBase::Serialize(FunctionExpression& node)
|
void AstSerializerBase::Serialize(FunctionExpression& node)
|
||||||
{
|
{
|
||||||
SizeT(node.funcId);
|
SizeT(node.funcId);
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||||
|
|
||||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||||
|
#include <cassert>
|
||||||
#include <Nazara/Shader/Debug.hpp>
|
#include <Nazara/Shader/Debug.hpp>
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
|
|
@ -104,7 +105,10 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
|
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
|
||||||
{
|
{
|
||||||
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1)
|
const ExpressionType* exprType = GetExpressionType(node);
|
||||||
|
assert(exprType);
|
||||||
|
|
||||||
|
if (IsPrimitiveType(*exprType) && node.componentCount > 1)
|
||||||
// Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned)
|
// Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned)
|
||||||
m_expressionCategory = ExpressionCategory::RValue;
|
m_expressionCategory = ExpressionCategory::RValue;
|
||||||
else
|
else
|
||||||
|
|
@ -133,6 +137,11 @@ namespace Nz::ShaderAst
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ShaderAstValueCategory::Visit(TypeExpression& /*node*/)
|
||||||
|
{
|
||||||
|
m_expressionCategory = ExpressionCategory::LValue;
|
||||||
|
}
|
||||||
|
|
||||||
void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/)
|
void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/)
|
||||||
{
|
{
|
||||||
m_expressionCategory = ExpressionCategory::LValue;
|
m_expressionCategory = ExpressionCategory::LValue;
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
namespace Nz::ShaderAst
|
namespace Nz::ShaderAst
|
||||||
{
|
{
|
||||||
ExpressionType GetExpressionType(const ConstantValue& constant)
|
ExpressionType GetConstantType(const ConstantValue& constant)
|
||||||
{
|
{
|
||||||
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,11 @@ namespace Nz::ShaderAst
|
||||||
|
|
||||||
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
|
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
|
||||||
{
|
{
|
||||||
const auto& targetFuncType = GetExpressionType(*node.targetFunction);
|
const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
|
||||||
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
assert(targetFuncType);
|
||||||
|
assert(std::holds_alternative<FunctionType>(*targetFuncType));
|
||||||
|
|
||||||
const auto& funcType = std::get<FunctionType>(targetFuncType);
|
const auto& funcType = std::get<FunctionType>(*targetFuncType);
|
||||||
|
|
||||||
assert(m_currentFunctionIndex);
|
assert(m_currentFunctionIndex);
|
||||||
if (m_currentVariableDeclIndex)
|
if (m_currentVariableDeclIndex)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -36,7 +36,7 @@ namespace Nz
|
||||||
AstRecursiveVisitor::Visit(node);
|
AstRecursiveVisitor::Visit(node);
|
||||||
|
|
||||||
assert(currentFunction);
|
assert(currentFunction);
|
||||||
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex);
|
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Visit(ShaderAst::ConditionalExpression& /*node*/) override
|
void Visit(ShaderAst::ConditionalExpression& /*node*/) override
|
||||||
|
|
@ -307,7 +307,7 @@ namespace Nz
|
||||||
Append(type.GetResultingValue());
|
Append(type.GetResultingValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
|
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected FunctionType");
|
throw std::runtime_error("unexpected FunctionType");
|
||||||
}
|
}
|
||||||
|
|
@ -829,8 +829,9 @@ namespace Nz
|
||||||
{
|
{
|
||||||
Visit(node.expr, true);
|
Visit(node.expr, true);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
|
||||||
assert(IsStructType(exprType));
|
assert(exprType);
|
||||||
|
assert(IsStructType(*exprType));
|
||||||
|
|
||||||
for (const std::string& identifier : node.identifiers)
|
for (const std::string& identifier : node.identifiers)
|
||||||
Append(".", identifier);
|
Append(".", identifier);
|
||||||
|
|
@ -840,8 +841,9 @@ namespace Nz
|
||||||
{
|
{
|
||||||
Visit(node.expr, true);
|
Visit(node.expr, true);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
|
||||||
assert(!IsStructType(exprType));
|
assert(exprType);
|
||||||
|
assert(!IsStructType(*exprType));
|
||||||
|
|
||||||
// Array access
|
// Array access
|
||||||
assert(node.indices.size() == 1);
|
assert(node.indices.size() == 1);
|
||||||
|
|
@ -1326,9 +1328,10 @@ namespace Nz
|
||||||
{
|
{
|
||||||
assert(node.returnExpr);
|
assert(node.returnExpr);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
|
const ShaderAst::ExpressionType* returnType = GetExpressionType(*node.returnExpr);
|
||||||
assert(IsStructType(returnType));
|
assert(returnType);
|
||||||
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
|
assert(IsStructType(*returnType));
|
||||||
|
std::size_t structIndex = std::get<ShaderAst::StructType>(*returnType).structIndex;
|
||||||
const auto& structData = Retrieve(m_currentState->structs, structIndex);
|
const auto& structData = Retrieve(m_currentState->structs, structIndex);
|
||||||
|
|
||||||
std::string outputStructVarName;
|
std::string outputStructVarName;
|
||||||
|
|
|
||||||
|
|
@ -182,7 +182,7 @@ namespace Nz
|
||||||
type.GetExpression()->Visit(*this);
|
type.GetExpression()->Visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
|
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected function type");
|
throw std::runtime_error("unexpected function type");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -263,7 +263,7 @@ namespace Nz::ShaderLang
|
||||||
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
|
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
|
||||||
|
|
||||||
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
|
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
|
||||||
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
||||||
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
|
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
|
||||||
|
|
||||||
const std::string& versionStr = std::get<std::string>(constantValue.value);
|
const std::string& versionStr = std::get<std::string>(constantValue.value);
|
||||||
|
|
@ -302,7 +302,7 @@ namespace Nz::ShaderLang
|
||||||
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
|
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
|
||||||
|
|
||||||
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
|
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
|
||||||
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
||||||
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
|
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
|
||||||
|
|
||||||
const std::string& uuidStr = std::get<std::string>(constantValue.value);
|
const std::string& uuidStr = std::get<std::string>(constantValue.value);
|
||||||
|
|
|
||||||
|
|
@ -67,9 +67,9 @@ namespace Nz
|
||||||
throw std::runtime_error("unexpected type");
|
throw std::runtime_error("unexpected type");
|
||||||
};
|
};
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
|
const ShaderAst::ExpressionType& resultType = *GetExpressionType(node);
|
||||||
const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left);
|
const ShaderAst::ExpressionType& leftType = *GetExpressionType(*node.left);
|
||||||
const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right);
|
const ShaderAst::ExpressionType& rightType = *GetExpressionType(*node.right);
|
||||||
|
|
||||||
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
|
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
|
||||||
//ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
|
//ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
|
||||||
|
|
@ -405,7 +405,7 @@ namespace Nz
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
|
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
|
||||||
{
|
{
|
||||||
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
|
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex;
|
||||||
|
|
||||||
UInt32 funcId = 0;
|
UInt32 funcId = 0;
|
||||||
for (const auto& [funcIndex, func] : m_funcData)
|
for (const auto& [funcIndex, func] : m_funcData)
|
||||||
|
|
@ -434,7 +434,7 @@ namespace Nz
|
||||||
UInt32 resultId = AllocateResultId();
|
UInt32 resultId = AllocateResultId();
|
||||||
m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender)
|
m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender)
|
||||||
{
|
{
|
||||||
appender(m_writer.GetTypeId(ShaderAst::GetExpressionType(node)));
|
appender(m_writer.GetTypeId(*ShaderAst::GetExpressionType(node)));
|
||||||
appender(resultId);
|
appender(resultId);
|
||||||
appender(funcId);
|
appender(funcId);
|
||||||
|
|
||||||
|
|
@ -718,9 +718,11 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsVectorType(parameterType));
|
assert(parameterType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
assert(IsVectorType(*parameterType));
|
||||||
|
|
||||||
|
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||||
|
|
||||||
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
||||||
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
||||||
|
|
@ -733,10 +735,11 @@ namespace Nz
|
||||||
|
|
||||||
case ShaderAst::IntrinsicType::DotProduct:
|
case ShaderAst::IntrinsicType::DotProduct:
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsVectorType(vecExprType));
|
assert(vecExprType);
|
||||||
|
assert(IsVectorType(*vecExprType));
|
||||||
|
|
||||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
|
||||||
|
|
||||||
UInt32 typeId = m_writer.GetTypeId(vecType.type);
|
UInt32 typeId = m_writer.GetTypeId(vecType.type);
|
||||||
|
|
||||||
|
|
@ -754,9 +757,10 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
|
assert(parameterType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
|
||||||
|
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||||
|
|
||||||
UInt32 param = EvaluateExpression(node.parameters[0]);
|
UInt32 param = EvaluateExpression(node.parameters[0]);
|
||||||
UInt32 resultId = m_writer.AllocateResultId();
|
UInt32 resultId = m_writer.AllocateResultId();
|
||||||
|
|
@ -770,10 +774,11 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsVectorType(vecExprType));
|
assert(vecExprType);
|
||||||
|
assert(IsVectorType(*vecExprType));
|
||||||
|
|
||||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(vecType.type);
|
UInt32 typeId = m_writer.GetTypeId(vecType.type);
|
||||||
|
|
||||||
UInt32 vec = EvaluateExpression(node.parameters[0]);
|
UInt32 vec = EvaluateExpression(node.parameters[0]);
|
||||||
|
|
@ -790,15 +795,16 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
|
assert(parameterType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
|
||||||
|
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||||
|
|
||||||
ShaderAst::PrimitiveType basicType;
|
ShaderAst::PrimitiveType basicType;
|
||||||
if (IsPrimitiveType(parameterType))
|
if (IsPrimitiveType(*parameterType))
|
||||||
basicType = std::get<ShaderAst::PrimitiveType>(parameterType);
|
basicType = std::get<ShaderAst::PrimitiveType>(*parameterType);
|
||||||
else if (IsVectorType(parameterType))
|
else if (IsVectorType(*parameterType))
|
||||||
basicType = std::get<ShaderAst::VectorType>(parameterType).type;
|
basicType = std::get<ShaderAst::VectorType>(*parameterType).type;
|
||||||
else
|
else
|
||||||
throw std::runtime_error("unexpected expression type");
|
throw std::runtime_error("unexpected expression type");
|
||||||
|
|
||||||
|
|
@ -837,10 +843,11 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsVectorType(vecExprType));
|
assert(vecExprType);
|
||||||
|
assert(IsVectorType(*vecExprType));
|
||||||
|
|
||||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(vecType);
|
UInt32 typeId = m_writer.GetTypeId(vecType);
|
||||||
|
|
||||||
UInt32 vec = EvaluateExpression(node.parameters[0]);
|
UInt32 vec = EvaluateExpression(node.parameters[0]);
|
||||||
|
|
@ -856,9 +863,10 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
|
assert(parameterType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
|
||||||
|
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||||
|
|
||||||
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
||||||
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
||||||
|
|
@ -873,9 +881,10 @@ namespace Nz
|
||||||
{
|
{
|
||||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||||
assert(IsVectorType(parameterType));
|
assert(parameterType);
|
||||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
assert(IsVectorType(*parameterType));
|
||||||
|
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||||
|
|
||||||
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
||||||
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
||||||
|
|
@ -951,20 +960,22 @@ namespace Nz
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression);
|
const ShaderAst::ExpressionType* swizzledExpressionType = GetExpressionType(*node.expression);
|
||||||
|
assert(swizzledExpressionType);
|
||||||
|
|
||||||
UInt32 exprResultId = EvaluateExpression(node.expression);
|
UInt32 exprResultId = EvaluateExpression(node.expression);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
|
const ShaderAst::ExpressionType* targetExprType = GetExpressionType(node);
|
||||||
|
assert(targetExprType);
|
||||||
|
|
||||||
if (node.componentCount > 1)
|
if (node.componentCount > 1)
|
||||||
{
|
{
|
||||||
assert(IsVectorType(targetExprType));
|
assert(IsVectorType(*targetExprType));
|
||||||
|
|
||||||
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(targetExprType);
|
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(*targetExprType);
|
||||||
|
|
||||||
UInt32 resultId = m_writer.AllocateResultId();
|
UInt32 resultId = m_writer.AllocateResultId();
|
||||||
if (IsVectorType(swizzledExpressionType))
|
if (IsVectorType(*swizzledExpressionType))
|
||||||
{
|
{
|
||||||
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
|
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
|
||||||
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
|
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
|
||||||
|
|
@ -980,7 +991,7 @@ namespace Nz
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
assert(IsPrimitiveType(swizzledExpressionType));
|
assert(IsPrimitiveType(*swizzledExpressionType));
|
||||||
|
|
||||||
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
|
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
|
||||||
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
|
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
|
||||||
|
|
@ -995,10 +1006,10 @@ namespace Nz
|
||||||
|
|
||||||
PushResultId(resultId);
|
PushResultId(resultId);
|
||||||
}
|
}
|
||||||
else if (IsVectorType(swizzledExpressionType))
|
else if (IsVectorType(*swizzledExpressionType))
|
||||||
{
|
{
|
||||||
assert(IsPrimitiveType(targetExprType));
|
assert(IsPrimitiveType(*targetExprType));
|
||||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(*targetExprType);
|
||||||
|
|
||||||
// Extract a single component from the vector
|
// Extract a single component from the vector
|
||||||
assert(node.componentCount == 1);
|
assert(node.componentCount == 1);
|
||||||
|
|
@ -1011,8 +1022,8 @@ namespace Nz
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// Swizzling a primitive to itself (a.x for example), don't do anything
|
// Swizzling a primitive to itself (a.x for example), don't do anything
|
||||||
assert(IsPrimitiveType(swizzledExpressionType));
|
assert(IsPrimitiveType(*swizzledExpressionType));
|
||||||
assert(IsPrimitiveType(targetExprType));
|
assert(IsPrimitiveType(*targetExprType));
|
||||||
assert(node.componentCount == 1);
|
assert(node.componentCount == 1);
|
||||||
assert(node.components[0] == 0);
|
assert(node.components[0] == 0);
|
||||||
|
|
||||||
|
|
@ -1022,8 +1033,11 @@ namespace Nz
|
||||||
|
|
||||||
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
|
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
|
const ShaderAst::ExpressionType* resultType = GetExpressionType(node);
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expression);
|
assert(resultType);
|
||||||
|
|
||||||
|
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expression);
|
||||||
|
assert(exprType);
|
||||||
|
|
||||||
UInt32 operand = EvaluateExpression(node.expression);
|
UInt32 operand = EvaluateExpression(node.expression);
|
||||||
|
|
||||||
|
|
@ -1033,11 +1047,11 @@ namespace Nz
|
||||||
{
|
{
|
||||||
case ShaderAst::UnaryType::LogicalNot:
|
case ShaderAst::UnaryType::LogicalNot:
|
||||||
{
|
{
|
||||||
assert(IsPrimitiveType(exprType));
|
assert(IsPrimitiveType(*exprType));
|
||||||
assert(std::get<ShaderAst::PrimitiveType>(resultType) == ShaderAst::PrimitiveType::Boolean);
|
assert(std::get<ShaderAst::PrimitiveType>(*resultType) == ShaderAst::PrimitiveType::Boolean);
|
||||||
|
|
||||||
UInt32 resultId = m_writer.AllocateResultId();
|
UInt32 resultId = m_writer.AllocateResultId();
|
||||||
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(resultType), resultId, operand);
|
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(*resultType), resultId, operand);
|
||||||
|
|
||||||
return resultId;
|
return resultId;
|
||||||
}
|
}
|
||||||
|
|
@ -1045,10 +1059,10 @@ namespace Nz
|
||||||
case ShaderAst::UnaryType::Minus:
|
case ShaderAst::UnaryType::Minus:
|
||||||
{
|
{
|
||||||
ShaderAst::PrimitiveType basicType;
|
ShaderAst::PrimitiveType basicType;
|
||||||
if (IsPrimitiveType(exprType))
|
if (IsPrimitiveType(*exprType))
|
||||||
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
basicType = std::get<ShaderAst::PrimitiveType>(*exprType);
|
||||||
else if (IsVectorType(exprType))
|
else if (IsVectorType(*exprType))
|
||||||
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
basicType = std::get<ShaderAst::VectorType>(*exprType).type;
|
||||||
else
|
else
|
||||||
throw std::runtime_error("unexpected expression type");
|
throw std::runtime_error("unexpected expression type");
|
||||||
|
|
||||||
|
|
@ -1057,12 +1071,12 @@ namespace Nz
|
||||||
switch (basicType)
|
switch (basicType)
|
||||||
{
|
{
|
||||||
case ShaderAst::PrimitiveType::Float32:
|
case ShaderAst::PrimitiveType::Float32:
|
||||||
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand);
|
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(*resultType), resultId, operand);
|
||||||
return resultId;
|
return resultId;
|
||||||
|
|
||||||
case ShaderAst::PrimitiveType::Int32:
|
case ShaderAst::PrimitiveType::Int32:
|
||||||
case ShaderAst::PrimitiveType::UInt32:
|
case ShaderAst::PrimitiveType::UInt32:
|
||||||
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(resultType), resultId, operand);
|
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(*resultType), resultId, operand);
|
||||||
return resultId;
|
return resultId;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
|
|
@ -76,9 +76,10 @@ namespace Nz
|
||||||
{
|
{
|
||||||
node.expr->Visit(*this);
|
node.expr->Visit(*this);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
|
||||||
|
assert(exprType);
|
||||||
|
|
||||||
UInt32 typeId = m_writer.GetTypeId(exprType);
|
UInt32 typeId = m_writer.GetTypeId(*exprType);
|
||||||
|
|
||||||
assert(node.indices.size() == 1);
|
assert(node.indices.size() == 1);
|
||||||
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
|
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
|
||||||
|
|
@ -88,7 +89,7 @@ namespace Nz
|
||||||
[&](const Pointer& pointer)
|
[&](const Pointer& pointer)
|
||||||
{
|
{
|
||||||
PointerChainAccess pointerChainAccess;
|
PointerChainAccess pointerChainAccess;
|
||||||
pointerChainAccess.exprType = &exprType;
|
pointerChainAccess.exprType = exprType;
|
||||||
pointerChainAccess.indices = { indexId };
|
pointerChainAccess.indices = { indexId };
|
||||||
pointerChainAccess.pointedTypeId = pointer.pointedTypeId;
|
pointerChainAccess.pointedTypeId = pointer.pointedTypeId;
|
||||||
pointerChainAccess.pointerId = pointer.pointerId;
|
pointerChainAccess.pointerId = pointer.pointerId;
|
||||||
|
|
@ -98,7 +99,7 @@ namespace Nz
|
||||||
},
|
},
|
||||||
[&](PointerChainAccess& pointerChainAccess)
|
[&](PointerChainAccess& pointerChainAccess)
|
||||||
{
|
{
|
||||||
pointerChainAccess.exprType = &exprType;
|
pointerChainAccess.exprType = exprType;
|
||||||
pointerChainAccess.indices.push_back(indexId);
|
pointerChainAccess.indices.push_back(indexId);
|
||||||
},
|
},
|
||||||
[&](const Value& value)
|
[&](const Value& value)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||||
#include <Nazara/Shader/SpirvBlock.hpp>
|
#include <Nazara/Shader/SpirvBlock.hpp>
|
||||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||||
|
#include <cassert>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <Nazara/Shader/Debug.hpp>
|
#include <Nazara/Shader/Debug.hpp>
|
||||||
|
|
||||||
|
|
@ -61,11 +62,12 @@ namespace Nz
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
|
|
||||||
|
|
||||||
assert(swizzledPointer.componentCount == 1);
|
assert(swizzledPointer.componentCount == 1);
|
||||||
|
|
||||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME
|
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node);
|
||||||
|
assert(exprType);
|
||||||
|
|
||||||
|
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, swizzledPointer.storage); //< FIXME
|
||||||
|
|
||||||
// Access chain
|
// Access chain
|
||||||
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
|
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
|
||||||
|
|
@ -86,14 +88,15 @@ namespace Nz
|
||||||
{
|
{
|
||||||
node.expr->Visit(*this);
|
node.expr->Visit(*this);
|
||||||
|
|
||||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
|
||||||
|
assert(exprType);
|
||||||
|
|
||||||
std::visit(Overloaded
|
std::visit(Overloaded
|
||||||
{
|
{
|
||||||
[&](const Pointer& pointer)
|
[&](const Pointer& pointer)
|
||||||
{
|
{
|
||||||
UInt32 resultId = m_visitor.AllocateResultId();
|
UInt32 resultId = m_visitor.AllocateResultId();
|
||||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
|
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, pointer.storage); //< FIXME
|
||||||
|
|
||||||
assert(node.indices.size() == 1);
|
assert(node.indices.size() == 1);
|
||||||
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
|
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
|
||||||
|
|
@ -117,13 +120,14 @@ namespace Nz
|
||||||
{
|
{
|
||||||
[&](const Pointer& pointer)
|
[&](const Pointer& pointer)
|
||||||
{
|
{
|
||||||
const auto& expressionType = GetExpressionType(*node.expression);
|
const ShaderAst::ExpressionType* expressionType = GetExpressionType(*node.expression);
|
||||||
assert(IsVectorType(expressionType));
|
assert(expressionType);
|
||||||
|
assert(IsVectorType(*expressionType));
|
||||||
|
|
||||||
SwizzledPointer swizzledPointer;
|
SwizzledPointer swizzledPointer;
|
||||||
swizzledPointer.pointerId = pointer.pointerId;
|
swizzledPointer.pointerId = pointer.pointerId;
|
||||||
swizzledPointer.storage = pointer.storage;
|
swizzledPointer.storage = pointer.storage;
|
||||||
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType);
|
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(*expressionType);
|
||||||
swizzledPointer.componentCount = node.componentCount;
|
swizzledPointer.componentCount = node.componentCount;
|
||||||
swizzledPointer.swizzleIndices = node.components;
|
swizzledPointer.swizzleIndices = node.components;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ namespace Nz
|
||||||
for (const auto& parameter : node.parameters)
|
for (const auto& parameter : node.parameters)
|
||||||
{
|
{
|
||||||
auto& var = func.variables.emplace_back();
|
auto& var = func.variables.emplace_back();
|
||||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
|
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(*GetExpressionType(*parameter), SpirvStorageClass::Function));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue