From 937c2ffbcf22935b67fbc90c2eedfb37c185b458 Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Tue, 7 Jan 2025 15:24:10 +0200 Subject: [PATCH] [ruby] In Pattern Variable Scoping Fix (#5208) Fixed an issue where the variables that the pattern extracted to were interpreted as fields instead of local variables. Also makes sure the pattern match call happens on the original expression and not on the LHS match variable. --- .../AstForControlStructuresCreator.scala | 17 +++++++---------- .../astcreation/AstForExpressionsCreator.scala | 10 +++++++++- .../astcreation/RubyIntermediateAst.scala | 4 +++- .../joern/rubysrc2cpg/querying/CaseTests.scala | 4 ++-- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala index 38ebb01b72b7..08757227bc30 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala @@ -299,17 +299,13 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo case x: ArrayPattern => val condition = expr.map(e => BinaryExpression(x, "===", e)(x.span)).getOrElse(inClause.pattern) val body = inClause.body + val variables = x.children.collect { case x: MatchVariable => x } - val variables = x.children.collect { case x: MatchVariable => - x - } - - val conditionBody = if (variables.nonEmpty) { - StatementList(variables.map { x => - val lhs = SimpleIdentifier()(x.span) - SingleAssignment(lhs, "=", x)( + val conditionBody = if (variables.nonEmpty && expr.isDefined) { + StatementList(variables.map { lhs => + SingleAssignment(lhs, "=", MatchVariable()(expr.get.span))( inClause.span - .spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${lhs.span.text})") + .spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${expr.get.text})") ) } :+ body)(body.span) } else { @@ -317,7 +313,8 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo } (condition, conditionBody) - case x => (x, inClause.body) + case x => + (x, inClause.body) } val conditional = IfExpression( diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index ecddf5fa03d4..44499e57e6a5 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -520,6 +520,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val asts = astsForStatement(x.multipleAssignment) val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) return callAst(call, asts :+ rhsAst) + case x: MatchVariable => + handleVariableOccurrence(x.toSimpleIdentifier) // Create local variable under this scope + val matchIden = astForExpression(x.toSimpleIdentifier) + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + return callAst(call, matchIden :: rhsAst :: Nil) case _ => astForExpression(node.lhs) } @@ -618,7 +623,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForArrayPattern(node: ArrayPattern): Ast = { val callNode_ = callNode(node, code(node), Operators.arrayInitializer, Operators.arrayInitializer, DispatchTypes.STATIC_DISPATCH) - val childrenAst = node.children.map(astForExpression) + val childrenAst = node.children.map { + case x: MatchVariable => astForExpression(SimpleIdentifier()(x.span)) + case x => astForExpression(x) + } callAst(callNode_, childrenAst) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index 406893179b8e..ddbd417b900d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -348,7 +348,9 @@ object RubyIntermediateAst { final case class ArrayPattern(children: List[RubyExpression])(span: TextSpan) extends RubyExpression(span) - final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) + final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) { + def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier()(span) + } final case class NextExpression()(span: TextSpan) extends RubyExpression(span) with ControlFlowStatement diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala index 6a05a907b06e..23bdcc2cdb06 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala @@ -202,7 +202,7 @@ class CaseTests extends RubyCode2CpgFixture { lhs.name shouldBe "result" rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch - rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(result)" + rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}()" case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") } @@ -211,7 +211,7 @@ class CaseTests extends RubyCode2CpgFixture { lhs.name shouldBe "notResult" rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch - rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(notResult)" + rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}()" case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") } case _ => fail(s"Expected two true branches")