diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala index 38ebb01b72b7..08757227bc30 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala @@ -299,17 +299,13 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo case x: ArrayPattern => val condition = expr.map(e => BinaryExpression(x, "===", e)(x.span)).getOrElse(inClause.pattern) val body = inClause.body + val variables = x.children.collect { case x: MatchVariable => x } - val variables = x.children.collect { case x: MatchVariable => - x - } - - val conditionBody = if (variables.nonEmpty) { - StatementList(variables.map { x => - val lhs = SimpleIdentifier()(x.span) - SingleAssignment(lhs, "=", x)( + val conditionBody = if (variables.nonEmpty && expr.isDefined) { + StatementList(variables.map { lhs => + SingleAssignment(lhs, "=", MatchVariable()(expr.get.span))( inClause.span - .spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${lhs.span.text})") + .spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${expr.get.text})") ) } :+ body)(body.span) } else { @@ -317,7 +313,8 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo } (condition, conditionBody) - case x => (x, inClause.body) + case x => + (x, inClause.body) } val conditional = IfExpression( diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index ecddf5fa03d4..44499e57e6a5 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -520,6 +520,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val asts = astsForStatement(x.multipleAssignment) val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) return callAst(call, asts :+ rhsAst) + case x: MatchVariable => + handleVariableOccurrence(x.toSimpleIdentifier) // Create local variable under this scope + val matchIden = astForExpression(x.toSimpleIdentifier) + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + return callAst(call, matchIden :: rhsAst :: Nil) case _ => astForExpression(node.lhs) } @@ -618,7 +623,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForArrayPattern(node: ArrayPattern): Ast = { val callNode_ = callNode(node, code(node), Operators.arrayInitializer, Operators.arrayInitializer, DispatchTypes.STATIC_DISPATCH) - val childrenAst = node.children.map(astForExpression) + val childrenAst = node.children.map { + case x: MatchVariable => astForExpression(SimpleIdentifier()(x.span)) + case x => astForExpression(x) + } callAst(callNode_, childrenAst) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index 406893179b8e..ddbd417b900d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -348,7 +348,9 @@ object RubyIntermediateAst { final case class ArrayPattern(children: List[RubyExpression])(span: TextSpan) extends RubyExpression(span) - final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) + final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) { + def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier()(span) + } final case class NextExpression()(span: TextSpan) extends RubyExpression(span) with ControlFlowStatement diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala index 6a05a907b06e..23bdcc2cdb06 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala @@ -202,7 +202,7 @@ class CaseTests extends RubyCode2CpgFixture { lhs.name shouldBe "result" rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch - rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(result)" + rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}()" case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") } @@ -211,7 +211,7 @@ class CaseTests extends RubyCode2CpgFixture { lhs.name shouldBe "notResult" rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch - rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(notResult)" + rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}()" case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") } case _ => fail(s"Expected two true branches")