Shader: Fix parsing of unary/dot/indices/and/or
This commit is contained in:
parent
4fa3de519c
commit
a54f70fd24
|
|
@ -85,6 +85,10 @@ namespace Nz::ShaderLang
|
|||
void ParseModuleStatement(std::vector<ShaderAst::ExprValue> attributes);
|
||||
void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue);
|
||||
|
||||
ShaderAst::ExpressionPtr BuildIdentifierAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs);
|
||||
ShaderAst::ExpressionPtr BuildIndexAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs);
|
||||
ShaderAst::ExpressionPtr BuildBinary(ShaderAst::BinaryType binaryType, ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs);
|
||||
|
||||
// Statements
|
||||
ShaderAst::StatementPtr ParseAliasDeclaration();
|
||||
ShaderAst::StatementPtr ParseBranchStatement();
|
||||
|
|
@ -110,10 +114,10 @@ namespace Nz::ShaderLang
|
|||
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);
|
||||
ShaderAst::ExpressionPtr ParseConstSelectExpression();
|
||||
ShaderAst::ExpressionPtr ParseExpression();
|
||||
std::vector<ShaderAst::ExpressionPtr> ParseExpressionList(TokenType terminationToken);
|
||||
ShaderAst::ExpressionPtr ParseFloatingPointExpression();
|
||||
ShaderAst::ExpressionPtr ParseIdentifier();
|
||||
ShaderAst::ExpressionPtr ParseIntegerExpression();
|
||||
std::vector<ShaderAst::ExpressionPtr> ParseParameters();
|
||||
ShaderAst::ExpressionPtr ParseParenthesisExpression();
|
||||
ShaderAst::ExpressionPtr ParsePrimaryExpression();
|
||||
ShaderAst::ExpressionPtr ParseStringExpression();
|
||||
|
|
|
|||
|
|
@ -390,6 +390,24 @@ namespace Nz::ShaderLang
|
|||
Expect(Advance(), TokenType::Semicolon);
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::BuildIdentifierAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs)
|
||||
{
|
||||
if (rhs->GetType() == ShaderAst::NodeType::IdentifierExpression)
|
||||
return ShaderBuilder::AccessMember(std::move(lhs), { std::move(SafeCast<ShaderAst::IdentifierExpression&>(*rhs).identifier) });
|
||||
else
|
||||
return BuildIndexAccess(std::move(lhs), std::move(rhs));
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::BuildIndexAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs)
|
||||
{
|
||||
return ShaderBuilder::AccessIndex(std::move(lhs), std::move(rhs));
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::BuildBinary(ShaderAst::BinaryType binaryType, ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs)
|
||||
{
|
||||
return ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseAliasDeclaration()
|
||||
{
|
||||
Expect(Advance(), TokenType::Alias);
|
||||
|
|
@ -1124,59 +1142,25 @@ namespace Nz::ShaderLang
|
|||
if (tokenPrecedence < exprPrecedence)
|
||||
return lhs;
|
||||
|
||||
bool c = false;
|
||||
while (currentTokenType == TokenType::Dot || currentTokenType == TokenType::OpenSquareBracket)
|
||||
{
|
||||
c = true;
|
||||
|
||||
if (currentTokenType == TokenType::Dot)
|
||||
{
|
||||
std::unique_ptr<ShaderAst::AccessIdentifierExpression> accessMemberNode = std::make_unique<ShaderAst::AccessIdentifierExpression>();
|
||||
accessMemberNode->expr = std::move(lhs);
|
||||
|
||||
do
|
||||
{
|
||||
Consume();
|
||||
|
||||
accessMemberNode->identifiers.push_back(ParseIdentifierAsName());
|
||||
} while (Peek().type == TokenType::Dot);
|
||||
|
||||
lhs = std::move(accessMemberNode);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(currentTokenType == TokenType::OpenSquareBracket);
|
||||
|
||||
std::unique_ptr<ShaderAst::AccessIndexExpression> indexNode = std::make_unique<ShaderAst::AccessIndexExpression>();
|
||||
indexNode->expr = std::move(lhs);
|
||||
|
||||
do
|
||||
{
|
||||
Consume();
|
||||
|
||||
indexNode->indices.push_back(ParseExpression());
|
||||
}
|
||||
while (Peek().type == TokenType::Comma);
|
||||
|
||||
Expect(Advance(), TokenType::ClosingSquareBracket);
|
||||
|
||||
lhs = std::move(indexNode);
|
||||
}
|
||||
|
||||
currentTokenType = Peek().type;
|
||||
}
|
||||
|
||||
if (currentTokenType == TokenType::OpenParenthesis)
|
||||
{
|
||||
// Function call
|
||||
auto parameters = ParseParameters();
|
||||
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
|
||||
Consume();
|
||||
|
||||
c = true;
|
||||
// Function call
|
||||
auto parameters = ParseExpressionList(TokenType::ClosingParenthesis);
|
||||
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (c)
|
||||
if (currentTokenType == TokenType::OpenSquareBracket)
|
||||
{
|
||||
Consume();
|
||||
|
||||
// Indices
|
||||
auto parameters = ParseExpressionList(TokenType::ClosingSquareBracket);
|
||||
lhs = ShaderBuilder::AccessIndex(std::move(lhs), std::move(parameters));
|
||||
continue;
|
||||
}
|
||||
|
||||
Consume();
|
||||
ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression();
|
||||
|
|
@ -1187,28 +1171,30 @@ namespace Nz::ShaderLang
|
|||
if (tokenPrecedence < nextTokenPrecedence)
|
||||
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
|
||||
|
||||
ShaderAst::BinaryType binaryType;
|
||||
lhs = [&]
|
||||
{
|
||||
switch (currentTokenType)
|
||||
{
|
||||
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
|
||||
case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break;
|
||||
case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break;
|
||||
case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
|
||||
case TokenType::LogicalAnd: binaryType = ShaderAst::BinaryType::LogicalAnd; break;
|
||||
case TokenType::LogicalOr: binaryType = ShaderAst::BinaryType::LogicalOr; break;
|
||||
case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompGt; break;
|
||||
case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompGe; break;
|
||||
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
|
||||
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
|
||||
case TokenType::NotEqual: binaryType = ShaderAst::BinaryType::CompNe; break;
|
||||
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
|
||||
case TokenType::Dot:
|
||||
return BuildIdentifierAccess(std::move(lhs), std::move(rhs));
|
||||
|
||||
case TokenType::Divide: return BuildBinary(ShaderAst::BinaryType::Divide, std::move(lhs), std::move(rhs));
|
||||
case TokenType::Equal: return BuildBinary(ShaderAst::BinaryType::CompEq, std::move(lhs), std::move(rhs));
|
||||
case TokenType::LessThan: return BuildBinary(ShaderAst::BinaryType::CompLt, std::move(lhs), std::move(rhs));
|
||||
case TokenType::LessThanEqual: return BuildBinary(ShaderAst::BinaryType::CompLe, std::move(lhs), std::move(rhs));
|
||||
case TokenType::LogicalAnd: return BuildBinary(ShaderAst::BinaryType::LogicalAnd, std::move(lhs), std::move(rhs));
|
||||
case TokenType::LogicalOr: return BuildBinary(ShaderAst::BinaryType::LogicalOr, std::move(lhs), std::move(rhs));
|
||||
case TokenType::GreaterThan: return BuildBinary(ShaderAst::BinaryType::CompGt, std::move(lhs), std::move(rhs));
|
||||
case TokenType::GreaterThanEqual: return BuildBinary(ShaderAst::BinaryType::CompGe, std::move(lhs), std::move(rhs));
|
||||
case TokenType::Minus: return BuildBinary(ShaderAst::BinaryType::Subtract, std::move(lhs), std::move(rhs));
|
||||
case TokenType::Multiply: return BuildBinary(ShaderAst::BinaryType::Multiply, std::move(lhs), std::move(rhs));
|
||||
case TokenType::NotEqual: return BuildBinary(ShaderAst::BinaryType::CompNe, std::move(lhs), std::move(rhs));
|
||||
case TokenType::Plus: return BuildBinary(ShaderAst::BinaryType::Add, std::move(lhs), std::move(rhs));
|
||||
default:
|
||||
throw UnexpectedToken{};
|
||||
}
|
||||
}
|
||||
|
||||
lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1237,6 +1223,24 @@ namespace Nz::ShaderLang
|
|||
return ParseBinOpRhs(0, ParsePrimaryExpression());
|
||||
}
|
||||
|
||||
std::vector<ShaderAst::ExpressionPtr> Parser::ParseExpressionList(TokenType terminationToken)
|
||||
{
|
||||
std::vector<ShaderAst::ExpressionPtr> parameters;
|
||||
bool first = true;
|
||||
while (Peek().type != terminationToken)
|
||||
{
|
||||
if (!first)
|
||||
Expect(Advance(), TokenType::Comma);
|
||||
|
||||
first = false;
|
||||
parameters.push_back(ParseExpression());
|
||||
}
|
||||
|
||||
Expect(Advance(), terminationToken);
|
||||
|
||||
return parameters;
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::ParseFloatingPointExpression()
|
||||
{
|
||||
const Token& floatingPointToken = Expect(Advance(), TokenType::FloatingPointValue);
|
||||
|
|
@ -1257,26 +1261,6 @@ namespace Nz::ShaderLang
|
|||
return ShaderBuilder::Constant(SafeCast<Int32>(std::get<long long>(integerToken.data))); //< FIXME
|
||||
}
|
||||
|
||||
std::vector<ShaderAst::ExpressionPtr> Parser::ParseParameters()
|
||||
{
|
||||
Expect(Advance(), TokenType::OpenParenthesis);
|
||||
|
||||
std::vector<ShaderAst::ExpressionPtr> parameters;
|
||||
bool first = true;
|
||||
while (Peek().type != TokenType::ClosingParenthesis)
|
||||
{
|
||||
if (!first)
|
||||
Expect(Advance(), TokenType::Comma);
|
||||
|
||||
first = false;
|
||||
parameters.push_back(ParseExpression());
|
||||
}
|
||||
|
||||
Expect(Advance(), TokenType::ClosingParenthesis);
|
||||
|
||||
return parameters;
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
|
||||
{
|
||||
Expect(Advance(), TokenType::OpenParenthesis);
|
||||
|
|
@ -1314,7 +1298,7 @@ namespace Nz::ShaderLang
|
|||
case TokenType::Minus:
|
||||
{
|
||||
Consume();
|
||||
ShaderAst::ExpressionPtr expr = ParsePrimaryExpression();
|
||||
ShaderAst::ExpressionPtr expr = ParseExpression();
|
||||
|
||||
return ShaderBuilder::Unary(ShaderAst::UnaryType::Minus, std::move(expr));
|
||||
}
|
||||
|
|
@ -1322,7 +1306,7 @@ namespace Nz::ShaderLang
|
|||
case TokenType::Plus:
|
||||
{
|
||||
Consume();
|
||||
ShaderAst::ExpressionPtr expr = ParsePrimaryExpression();
|
||||
ShaderAst::ExpressionPtr expr = ParseExpression();
|
||||
|
||||
return ShaderBuilder::Unary(ShaderAst::UnaryType::Plus, std::move(expr));
|
||||
}
|
||||
|
|
@ -1330,7 +1314,7 @@ namespace Nz::ShaderLang
|
|||
case TokenType::Not:
|
||||
{
|
||||
Consume();
|
||||
ShaderAst::ExpressionPtr expr = ParsePrimaryExpression();
|
||||
ShaderAst::ExpressionPtr expr = ParseExpression();
|
||||
|
||||
return ShaderBuilder::Unary(ShaderAst::UnaryType::LogicalNot, std::move(expr));
|
||||
}
|
||||
|
|
@ -1404,12 +1388,12 @@ namespace Nz::ShaderLang
|
|||
switch (token)
|
||||
{
|
||||
case TokenType::Divide: return 80;
|
||||
case TokenType::Dot: return 100;
|
||||
case TokenType::Dot: return 150;
|
||||
case TokenType::Equal: return 50;
|
||||
case TokenType::LessThan: return 40;
|
||||
case TokenType::LessThanEqual: return 40;
|
||||
case TokenType::LogicalAnd: return 120;
|
||||
case TokenType::LogicalOr: return 140;
|
||||
case TokenType::LogicalAnd: return 20;
|
||||
case TokenType::LogicalOr: return 10;
|
||||
case TokenType::GreaterThan: return 40;
|
||||
case TokenType::GreaterThanEqual: return 40;
|
||||
case TokenType::Multiply: return 80;
|
||||
|
|
|
|||
|
|
@ -88,6 +88,97 @@ OpStore
|
|||
OpBranch
|
||||
OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd)");
|
||||
}
|
||||
|
||||
WHEN("using a more complex branch")
|
||||
{
|
||||
std::string_view nzslSource = R"(
|
||||
[nzsl_version("1.0")]
|
||||
module;
|
||||
|
||||
struct inputStruct
|
||||
{
|
||||
value: f32
|
||||
}
|
||||
|
||||
external
|
||||
{
|
||||
[set(0), binding(0)] data: uniform[inputStruct]
|
||||
}
|
||||
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: f32;
|
||||
if (data.value > 42.0 || data.value <= 50.0 && data.value < 0.0)
|
||||
value = 1.0;
|
||||
else
|
||||
value = 0.0;
|
||||
}
|
||||
)";
|
||||
|
||||
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
|
||||
shaderModule = SanitizeModule(*shaderModule);
|
||||
|
||||
ExpectGLSL(*shaderModule, R"(
|
||||
void main()
|
||||
{
|
||||
float value;
|
||||
if ((data.value > (42.000000)) || ((data.value <= (50.000000)) && (data.value < (0.000000))))
|
||||
{
|
||||
value = 1.000000;
|
||||
}
|
||||
else
|
||||
{
|
||||
value = 0.000000;
|
||||
}
|
||||
|
||||
}
|
||||
)");
|
||||
|
||||
ExpectNZSL(*shaderModule, R"(
|
||||
[entry(frag)]
|
||||
fn main()
|
||||
{
|
||||
let value: f32;
|
||||
if ((data.value > (42.000000)) || ((data.value <= (50.000000)) && (data.value < (0.000000))))
|
||||
{
|
||||
value = 1.000000;
|
||||
}
|
||||
else
|
||||
{
|
||||
value = 0.000000;
|
||||
}
|
||||
|
||||
}
|
||||
)");
|
||||
|
||||
ExpectSPIRV(*shaderModule, R"(
|
||||
OpFunction
|
||||
OpLabel
|
||||
OpVariable
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFOrdGreaterThanEqual
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFOrdLessThanEqual
|
||||
OpAccessChain
|
||||
OpLoad
|
||||
OpFOrdLessThan
|
||||
OpLogicalAnd
|
||||
OpLogicalOr
|
||||
OpSelectionMerge
|
||||
OpBranchConditional
|
||||
OpLabel
|
||||
OpStore
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpStore
|
||||
OpBranch
|
||||
OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd)");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ fn GetValue() -> f32
|
|||
fn main() -> FragOut
|
||||
{
|
||||
let output: FragOut;
|
||||
output.value = GetValue();
|
||||
output.value = -GetValue();
|
||||
|
||||
return output;
|
||||
}
|
||||
|
|
@ -49,7 +49,7 @@ layout(location = 0) out float _NzOut_value;
|
|||
void main()
|
||||
{
|
||||
FragOut output_;
|
||||
output_.value = GetValue();
|
||||
output_.value = -GetValue();
|
||||
|
||||
_NzOut_value = output_.value;
|
||||
return;
|
||||
|
|
@ -66,7 +66,7 @@ fn GetValue() -> f32
|
|||
fn main() -> FragOut
|
||||
{
|
||||
let output: FragOut;
|
||||
output.value = GetValue();
|
||||
output.value = -GetValue();
|
||||
return output;
|
||||
}
|
||||
)");
|
||||
|
|
@ -80,6 +80,7 @@ OpFunction
|
|||
OpLabel
|
||||
OpVariable
|
||||
OpFunctionCall
|
||||
OpFNegate
|
||||
OpAccessChain
|
||||
OpStore
|
||||
OpLoad
|
||||
|
|
|
|||
Loading…
Reference in New Issue