Shader: Add support for logical and/or

This commit is contained in:
Jérôme Leclercq 2021-07-07 15:23:39 +02:00
parent ea899e4361
commit 72edff30c7
9 changed files with 192 additions and 76 deletions

View File

@ -34,16 +34,18 @@ namespace Nz
enum class BinaryType enum class BinaryType
{ {
Add, //< + Add, //< +
CompEq, //< == CompEq, //< ==
CompGe, //< >= CompGe, //< >=
CompGt, //< > CompGt, //< >
CompLe, //< <= CompLe, //< <=
CompLt, //< < CompLt, //< <
CompNe, //< <= CompNe, //< <=
Divide, //< / Divide, //< /
Multiply, //< * Multiply, //< *
Subtract, //< - LogicalAnd, //< &&
LogicalOr, //< ||
Subtract, //< -
}; };
enum class BuiltinEntry enum class BuiltinEntry

View File

@ -36,6 +36,8 @@ NAZARA_SHADERLANG_TOKEN(If)
NAZARA_SHADERLANG_TOKEN(LessThan) NAZARA_SHADERLANG_TOKEN(LessThan)
NAZARA_SHADERLANG_TOKEN(LessThanEqual) NAZARA_SHADERLANG_TOKEN(LessThanEqual)
NAZARA_SHADERLANG_TOKEN(Let) NAZARA_SHADERLANG_TOKEN(Let)
NAZARA_SHADERLANG_TOKEN(LogicalAnd)
NAZARA_SHADERLANG_TOKEN(LogicalOr)
NAZARA_SHADERLANG_TOKEN(Multiply) NAZARA_SHADERLANG_TOKEN(Multiply)
NAZARA_SHADERLANG_TOKEN(Minus) NAZARA_SHADERLANG_TOKEN(Minus)
NAZARA_SHADERLANG_TOKEN(NotEqual) NAZARA_SHADERLANG_TOKEN(NotEqual)

View File

@ -152,6 +152,44 @@ namespace Nz::ShaderAst
using Op = BinaryCompNe<T1, T2>; using Op = BinaryCompNe<T1, T2>;
}; };
// LogicalAnd
template<typename T1, typename T2>
struct BinaryLogicalAndBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs && rhs);
}
};
template<typename T1, typename T2>
struct BinaryLogicalAnd;
template<typename T1, typename T2>
struct BinaryConstantPropagation<BinaryType::LogicalAnd, T1, T2>
{
using Op = BinaryLogicalAnd<T1, T2>;
};
// LogicalOr
template<typename T1, typename T2>
struct BinaryLogicalOrBase
{
std::unique_ptr<ConstantExpression> operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs || rhs);
}
};
template<typename T1, typename T2>
struct BinaryLogicalOr;
template<typename T1, typename T2>
struct BinaryConstantPropagation<BinaryType::LogicalOr, T1, T2>
{
using Op = BinaryLogicalOr<T1, T2>;
};
// Addition // Addition
template<typename T1, typename T2> template<typename T1, typename T2>
struct BinaryAdditionBase struct BinaryAdditionBase
@ -325,7 +363,6 @@ namespace Nz::ShaderAst
EnableOptimisation(BinaryCompEq, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompEq, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompEq, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompEq, Vector4i32, Vector4i32);
EnableOptimisation(BinaryCompGe, bool, bool);
EnableOptimisation(BinaryCompGe, double, double); EnableOptimisation(BinaryCompGe, double, double);
EnableOptimisation(BinaryCompGe, float, float); EnableOptimisation(BinaryCompGe, float, float);
EnableOptimisation(BinaryCompGe, Int32, Int32); EnableOptimisation(BinaryCompGe, Int32, Int32);
@ -336,7 +373,6 @@ namespace Nz::ShaderAst
EnableOptimisation(BinaryCompGe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompGe, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompGe, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompGe, Vector4i32, Vector4i32);
EnableOptimisation(BinaryCompGt, bool, bool);
EnableOptimisation(BinaryCompGt, double, double); EnableOptimisation(BinaryCompGt, double, double);
EnableOptimisation(BinaryCompGt, float, float); EnableOptimisation(BinaryCompGt, float, float);
EnableOptimisation(BinaryCompGt, Int32, Int32); EnableOptimisation(BinaryCompGt, Int32, Int32);
@ -347,7 +383,6 @@ namespace Nz::ShaderAst
EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32);
EnableOptimisation(BinaryCompLe, bool, bool);
EnableOptimisation(BinaryCompLe, double, double); EnableOptimisation(BinaryCompLe, double, double);
EnableOptimisation(BinaryCompLe, float, float); EnableOptimisation(BinaryCompLe, float, float);
EnableOptimisation(BinaryCompLe, Int32, Int32); EnableOptimisation(BinaryCompLe, Int32, Int32);
@ -358,7 +393,6 @@ namespace Nz::ShaderAst
EnableOptimisation(BinaryCompLe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompLe, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompLe, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompLe, Vector4i32, Vector4i32);
EnableOptimisation(BinaryCompLt, bool, bool);
EnableOptimisation(BinaryCompLt, double, double); EnableOptimisation(BinaryCompLt, double, double);
EnableOptimisation(BinaryCompLt, float, float); EnableOptimisation(BinaryCompLt, float, float);
EnableOptimisation(BinaryCompLt, Int32, Int32); EnableOptimisation(BinaryCompLt, Int32, Int32);
@ -380,6 +414,9 @@ namespace Nz::ShaderAst
EnableOptimisation(BinaryCompNe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompNe, Vector3i32, Vector3i32);
EnableOptimisation(BinaryCompNe, Vector4i32, Vector4i32); EnableOptimisation(BinaryCompNe, Vector4i32, Vector4i32);
EnableOptimisation(BinaryLogicalAnd, bool, bool);
EnableOptimisation(BinaryLogicalOr, bool, bool);
EnableOptimisation(BinaryAddition, double, double); EnableOptimisation(BinaryAddition, double, double);
EnableOptimisation(BinaryAddition, float, float); EnableOptimisation(BinaryAddition, float, float);
EnableOptimisation(BinaryAddition, Int32, Int32); EnableOptimisation(BinaryAddition, Int32, Int32);
@ -583,6 +620,14 @@ namespace Nz::ShaderAst
case BinaryType::CompNe: case BinaryType::CompNe:
optimized = PropagateBinaryConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant)); optimized = PropagateBinaryConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
break; break;
case BinaryType::LogicalAnd:
optimized = PropagateBinaryConstant<BinaryType::LogicalAnd>(std::move(lhsConstant), std::move(rhsConstant));
break;
case BinaryType::LogicalOr:
optimized = PropagateBinaryConstant<BinaryType::LogicalOr>(std::move(lhsConstant), std::move(rhsConstant));
break;
} }
if (optimized) if (optimized)

View File

@ -314,7 +314,18 @@ namespace Nz::ShaderAst
default: default:
throw AstError{ "incompatible types" }; throw AstError{ "incompatible types" };
} }
break;
} }
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
if (leftType != PrimitiveType::Boolean)
throw AstError{ "logical and/or are only supported on booleans" };
TypeMustMatch(clone->left, clone->right);
clone->cachedExpressionType = PrimitiveType::Boolean;
break;
} }
} }
else if (IsMatrixType(leftExprType)) else if (IsMatrixType(leftExprType))
@ -363,7 +374,13 @@ namespace Nz::ShaderAst
} }
else else
throw AstError{ "incompatible types" }; throw AstError{ "incompatible types" };
break;
} }
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
throw AstError{ "logical and/or are only supported on booleans" };
} }
} }
else if (IsVectorType(leftExprType)) else if (IsVectorType(leftExprType))
@ -402,7 +419,13 @@ namespace Nz::ShaderAst
} }
else else
throw AstError{ "incompatible types" }; throw AstError{ "incompatible types" };
break;
} }
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
throw AstError{ "logical and/or are only supported on booleans" };
} }
} }

View File

@ -763,6 +763,9 @@ namespace Nz
case ShaderAst::BinaryType::CompLe: Append(" <= "); break; case ShaderAst::BinaryType::CompLe: Append(" <= "); break;
case ShaderAst::BinaryType::CompLt: Append(" < "); break; case ShaderAst::BinaryType::CompLt: Append(" < "); break;
case ShaderAst::BinaryType::CompNe: Append(" != "); break; case ShaderAst::BinaryType::CompNe: Append(" != "); break;
case ShaderAst::BinaryType::LogicalAnd: Append(" && "); break;
case ShaderAst::BinaryType::LogicalOr: Append(" || "); break;
} }
Visit(node.right, true); Visit(node.right, true);

View File

@ -584,17 +584,20 @@ namespace Nz
switch (node.op) switch (node.op)
{ {
case ShaderAst::BinaryType::Add: Append(" + "); break; case ShaderAst::BinaryType::Add: Append(" + "); break;
case ShaderAst::BinaryType::Subtract: Append(" - "); break; case ShaderAst::BinaryType::Subtract: Append(" - "); break;
case ShaderAst::BinaryType::Multiply: Append(" * "); break; case ShaderAst::BinaryType::Multiply: Append(" * "); break;
case ShaderAst::BinaryType::Divide: Append(" / "); break; case ShaderAst::BinaryType::Divide: Append(" / "); break;
case ShaderAst::BinaryType::CompEq: Append(" == "); break; case ShaderAst::BinaryType::CompEq: Append(" == "); break;
case ShaderAst::BinaryType::CompGe: Append(" >= "); break; case ShaderAst::BinaryType::CompGe: Append(" >= "); break;
case ShaderAst::BinaryType::CompGt: Append(" > "); break; case ShaderAst::BinaryType::CompGt: Append(" > "); break;
case ShaderAst::BinaryType::CompLe: Append(" <= "); break; case ShaderAst::BinaryType::CompLe: Append(" <= "); break;
case ShaderAst::BinaryType::CompLt: Append(" < "); break; case ShaderAst::BinaryType::CompLt: Append(" < "); break;
case ShaderAst::BinaryType::CompNe: Append(" != "); break; case ShaderAst::BinaryType::CompNe: Append(" != "); break;
case ShaderAst::BinaryType::LogicalAnd: Append(" && "); break;
case ShaderAst::BinaryType::LogicalOr: Append(" || "); break;
} }
Visit(node.right, true); Visit(node.right, true);

View File

@ -246,6 +246,34 @@ namespace Nz::ShaderLang
break; break;
} }
case '|':
{
char next = Peek();
if (next == '|')
{
currentPos++;
tokenType = TokenType::LogicalOr;
}
else
throw UnrecognizedToken{}; //< TODO: Add BOR (a | b)
break;
}
case '&':
{
char next = Peek();
if (next == '&')
{
currentPos++;
tokenType = TokenType::LogicalAnd;
}
else
throw UnrecognizedToken{}; //< TODO: Add BAND (a & b)
break;
}
case '<': case '<':
{ {
char next = Peek(); char next = Peek();

View File

@ -917,6 +917,8 @@ namespace Nz::ShaderLang
case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break; case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break;
case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break; case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::LogicalAnd: binaryType = ShaderAst::BinaryType::LogicalAnd; break;
case TokenType::LogicalOr: binaryType = ShaderAst::BinaryType::LogicalOr; break;
case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompLt; break; case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
@ -1140,6 +1142,8 @@ namespace Nz::ShaderLang
case TokenType::Equal: return 50; case TokenType::Equal: return 50;
case TokenType::LessThan: return 40; case TokenType::LessThan: return 40;
case TokenType::LessThanEqual: return 40; case TokenType::LessThanEqual: return 40;
case TokenType::LogicalAnd: return 120;
case TokenType::LogicalOr: return 140;
case TokenType::GreaterThan: return 40; case TokenType::GreaterThan: return 40;
case TokenType::GreaterThanEqual: return 40; case TokenType::GreaterThanEqual: return 40;
case TokenType::Multiply: return 80; case TokenType::Multiply: return 80;

View File

@ -139,6 +139,60 @@ namespace Nz
break; break;
} }
case ShaderAst::BinaryType::Multiply:
{
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
{
if (IsPrimitiveType(leftType))
{
// Handle float * matrix|vector as matrix|vector * float
if (IsMatrixType(rightType))
{
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
}
else if (IsVectorType(rightType))
{
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
}
}
else if (IsPrimitiveType(rightType))
{
if (IsMatrixType(leftType))
return SpirvOp::OpMatrixTimesScalar;
else if (IsVectorType(leftType))
return SpirvOp::OpVectorTimesScalar;
}
else if (IsMatrixType(leftType))
{
if (IsMatrixType(rightType))
return SpirvOp::OpMatrixTimesMatrix;
else if (IsVectorType(rightType))
return SpirvOp::OpMatrixTimesVector;
}
else if (IsMatrixType(rightType))
{
assert(IsVectorType(leftType));
return SpirvOp::OpVectorTimesMatrix;
}
return SpirvOp::OpFMul;
}
case ShaderAst::PrimitiveType::Int32:
case ShaderAst::PrimitiveType::UInt32:
return SpirvOp::OpIMul;
default:
break;
}
break;
}
case ShaderAst::BinaryType::CompEq: case ShaderAst::BinaryType::CompEq:
{ {
switch (leftTypeBase) switch (leftTypeBase)
@ -255,59 +309,11 @@ namespace Nz
break; break;
} }
case ShaderAst::BinaryType::Multiply: case ShaderAst::BinaryType::LogicalAnd:
{ return SpirvOp::OpLogicalAnd;
switch (leftTypeBase)
{
case ShaderAst::PrimitiveType::Float32:
{
if (IsPrimitiveType(leftType))
{
// Handle float * matrix|vector as matrix|vector * float
if (IsMatrixType(rightType))
{
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
}
else if (IsVectorType(rightType))
{
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
}
}
else if (IsPrimitiveType(rightType))
{
if (IsMatrixType(leftType))
return SpirvOp::OpMatrixTimesScalar;
else if (IsVectorType(leftType))
return SpirvOp::OpVectorTimesScalar;
}
else if (IsMatrixType(leftType))
{
if (IsMatrixType(rightType))
return SpirvOp::OpMatrixTimesMatrix;
else if (IsVectorType(rightType))
return SpirvOp::OpMatrixTimesVector;
}
else if (IsMatrixType(rightType))
{
assert(IsVectorType(leftType));
return SpirvOp::OpVectorTimesMatrix;
}
return SpirvOp::OpFMul; case ShaderAst::BinaryType::LogicalOr:
} return SpirvOp::OpLogicalOr;
case ShaderAst::PrimitiveType::Int32:
case ShaderAst::PrimitiveType::UInt32:
return SpirvOp::OpIMul;
default:
break;
}
break;
}
} }
assert(false); assert(false);