Shader: Add support for pow intrinsic

This commit is contained in:
Jérôme Leclercq 2021-06-05 01:29:36 +02:00
parent 2d502775a6
commit 50bf26d92f
5 changed files with 48 additions and 1 deletions

View File

@ -93,6 +93,7 @@ namespace Nz
Length = 3,
Max = 4,
Min = 5,
Pow = 6,
SampleTexture = 2,
};

View File

@ -59,9 +59,10 @@ namespace Nz::ShaderAst
{
RegisterIntrinsic("cross", IntrinsicType::CrossProduct);
RegisterIntrinsic("dot", IntrinsicType::DotProduct);
RegisterIntrinsic("length", IntrinsicType::Length);
RegisterIntrinsic("max", IntrinsicType::Max);
RegisterIntrinsic("min", IntrinsicType::Min);
RegisterIntrinsic("length", IntrinsicType::Length);
RegisterIntrinsic("pow", IntrinsicType::Pow);
// Collect function name and their types
if (nodePtr->GetType() == NodeType::MultiStatement)
@ -1289,6 +1290,7 @@ namespace Nz::ShaderAst
case IntrinsicType::DotProduct:
case IntrinsicType::Max:
case IntrinsicType::Min:
case IntrinsicType::Pow:
{
if (node.parameters.size() != 2)
throw AstError { "Expected two parameters" };
@ -1379,6 +1381,20 @@ namespace Nz::ShaderAst
break;
}
case IntrinsicType::Pow:
{
const ExpressionType& type = GetExpressionType(*node.parameters.front());
if (!IsPrimitiveType(type) && !IsVectorType(type))
throw AstError{ "pow only works with primitive and vector types" };
if ((IsPrimitiveType(type) && std::get<PrimitiveType>(type) != PrimitiveType::Float32) ||
(IsVectorType(type) && std::get<VectorType>(type).type != PrimitiveType::Float32))
throw AstError{ "pow only works with floating-point primitive or vectors" };
node.cachedExpressionType = type;
break;
}
case IntrinsicType::SampleTexture:
{
node.cachedExpressionType = VectorType{ 4, std::get<SamplerType>(GetExpressionType(*node.parameters.front())).sampledType };

View File

@ -1020,6 +1020,10 @@ namespace Nz
Append("min");
break;
case ShaderAst::IntrinsicType::Pow:
Append("pow");
break;
case ShaderAst::IntrinsicType::SampleTexture:
Append("texture");
break;

View File

@ -799,6 +799,31 @@ namespace Nz
break;
}
case ShaderAst::IntrinsicType::Pow:
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(parameterType))
basicType = std::get<ShaderAst::PrimitiveType>(parameterType);
else if (IsVectorType(parameterType))
basicType = std::get<ShaderAst::VectorType>(parameterType).type;
else
throw std::runtime_error("unexpected expression type");
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpExtInst, typeId, resultId, glslInstructionSet, GLSLstd450Pow, firstParam, secondParam);
PushResultId(resultId);
break;
}
case ShaderAst::IntrinsicType::SampleTexture:
{
UInt32 typeId = m_writer.GetTypeId(ShaderAst::VectorType{4, ShaderAst::PrimitiveType::Float32});

View File

@ -324,6 +324,7 @@ namespace Nz
case ShaderAst::IntrinsicType::Length:
case ShaderAst::IntrinsicType::Max:
case ShaderAst::IntrinsicType::Min:
case ShaderAst::IntrinsicType::Pow:
extInsts.emplace("GLSL.std.450");
break;