From 21558c7c26249917993e54e71766129390a28a33 Mon Sep 17 00:00:00 2001 From: Reuben Steenekamp Date: Fri, 12 Jan 2024 16:12:55 +0200 Subject: [PATCH] [ruby] when Statements (Switch Statements) #3926 (#4029) This draft translates Ruby `case` expressions into `if-elif-...-else-end` chains. The expression matched against is assigned to a temporary variable `<tmp-#>`, where `#` is a number. I was not sure about the naming convention, but I wanted to be sure the variable is always fresh. A single `when` can contain a list of expressions to match against, which I translate into an `or-expression` if there is more than one. When matching against an expression `mExpr`, it is turned into a condition of the form `mExpr.=== <tmp-#>`. This list of expressions can contain a splat at the end, which we aren't handling yet. It may also remain to special-case the translation to a switch AST if all match expressions are literals. TextSpans for the generated intermediate AST are not sane and will need to be considered more carefully. --------- Co-authored-by: David Baker Effendi --- .../rubysrc2cpg/astcreation/AstCreator.scala | 1 + .../astcreation/AstForExpressionsCreator.scala | 2 + .../astcreation/AstForStatementsCreator.scala | 46 ++++++++++ .../astcreation/AstForTypesCreator.scala | 2 +- .../astcreation/FreshVariableCreator.scala | 15 +++ .../astcreation/RubyIntermediateAst.scala | 17 +++- .../rubysrc2cpg/parser/RubyNodeCreator.scala | 35 +++++-- .../rubysrc2cpg/querying/CaseTests.scala | 92 +++++++++++++++++++ 8 files changed, 202 insertions(+), 8 deletions(-) create mode 100644 joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/FreshVariableCreator.scala create mode 100644 joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index e10231424685..ca0e2861ae0a 100644 --- 
a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -22,6 +22,7 @@ class AstCreator(protected val filename: String, parser: ResourceManagedParser, with AstForExpressionsCreator with AstForFunctionsCreator with AstForTypesCreator + with FreshVariableCreator with AstNodeBuilder[RubyNode, AstCreator] { protected val logger: Logger = LoggerFactory.getLogger(getClass) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index cdd18f4def51..f63cea68bdc3 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -7,6 +7,8 @@ import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewLiteral, NewControlStructure} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, ControlStructureTypes} import io.shiftleft.semanticcpg.language.NodeOrdering.nodeList +import scala.collection.mutable +import io.joern.rubysrc2cpg.parser.RubyParser trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 5b8f765366ac..e9f815e7775a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ 
b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -14,6 +14,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case node: UntilExpression => astForUntilStatement(node) :: Nil case node: IfExpression => astForIfStatement(node) :: Nil case node: UnlessExpression => astForUnlessStatement(node) :: Nil + case node: CaseExpression => astsForCaseExpression(node) case node: StatementList => astForStatementList(node) :: Nil case node: SimpleCallWithBlock => astForSimpleCallWithBlock(node) :: Nil case node: MemberCallWithBlock => astForMemberCallWithBlock(node) :: Nil @@ -100,6 +101,51 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t controlStructureAst(ifNode, Some(notConditionAst), thenAst :: elseAsts) } + protected def astsForCaseExpression(node: CaseExpression): Seq[Ast] = { + def goCase(expr: Option[SimpleIdentifier]): List[RubyNode] = { + val elseThenClause: Option[RubyNode] = node.elseClause.map(_.asInstanceOf[ElseClause].thenClause) + val whenClauses = node.whenClauses.map(_.asInstanceOf[WhenClause]) + val ifElseChain = whenClauses.foldRight[Option[RubyNode]](elseThenClause) { + (whenClause: WhenClause, restClause: Option[RubyNode]) => + // We translate multiple match expressions into an or expression. + // There may be a splat as the last match expression, which is currently parsed as unknown + // A single match expression is compared using `.===` to the case target expression if it is present + // otherwise it is treated as a conditional. 
+ val conditions = whenClause.matchExpressions.map { mExpr => + expr.map(e => MemberCall(mExpr, ".", "===", List(e))(mExpr.span)).getOrElse(mExpr) + } ++ (whenClause.matchSplatExpression.iterator.flatMap { + case u: Unknown => List(u) + case e => + logger.warn("Splatting not implemented for `when` in ruby `case`") + List(Unknown()(e.span)) + }) + // There is always at least one match expression or a splat + // a splat will become an unknown in condition at the end + val condition = conditions.init.foldRight(conditions.last) { (cond, condAcc) => + BinaryExpression(cond, "||", condAcc)(whenClause.span) + } + val conditional = IfExpression( + condition, + whenClause.thenClause.asStatementList, + List(), + restClause.map { els => ElseClause(els.asStatementList)(els.span) } + )(node.span) + Some(conditional) + } + ifElseChain.iterator.toList + } + def generatedNode: StatementList = node.expression + .map { e => + val tmp = SimpleIdentifier(None)(e.span.spanStart(freshName)) + StatementList( + List(SingleAssignment(tmp, "=", e)(e.span)) ++ + goCase(Some(tmp)) + )(node.span) + } + .getOrElse(StatementList(goCase(None))(node.span)) + astsForStatement(generatedNode) + } + protected def astForStatementList(node: StatementList): Ast = { val block = blockNode(node) scope.pushNewScope(block) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala index 409aa102f2d2..f86bd73c7a0e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala @@ -57,7 +57,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList) val classBodyAsts = 
classBody.statements.flatMap(astsForStatement) match { case bodyAsts if shouldGenerateDefaultConstructorStack.head => - val bodyStart = classBody.span.spanStart + val bodyStart = classBody.span.spanStart() val initBody = StatementList(List())(bodyStart) val methodDecl = astForMethodDeclaration(MethodDeclaration("", List(), initBody)(bodyStart)) methodDecl :: bodyAsts diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/FreshVariableCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/FreshVariableCreator.scala new file mode 100644 index 000000000000..8a7c7bb9a24f --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/FreshVariableCreator.scala @@ -0,0 +1,15 @@ +package io.joern.rubysrc2cpg.astcreation + +import io.joern.rubysrc2cpg.astcreation.AstCreator +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.* + +trait FreshVariableCreator { this: AstCreator => + // This is in a single-threaded context. 
+ var tmpCounter: Int = 0 + private def tmpTemplate(id: Int): String = s"" + protected def freshName: String = { + val name = tmpTemplate(tmpCounter) + tmpCounter += 1 + name + } +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index e7d57911050a..fab2e4354873 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -9,7 +9,7 @@ object RubyIntermediateAst { columnEnd: Option[Integer], text: String ) { - def spanStart: TextSpan = TextSpan(line, column, line, column, "") + def spanStart(newText: String = ""): TextSpan = TextSpan(line, column, line, column, newText) } sealed class RubyNode(val span: TextSpan) { @@ -28,6 +28,7 @@ object RubyIntermediateAst { def asStatementList = node match case stmtList: StatementList => stmtList case _ => StatementList(List(node))(node.span) + } final case class Unknown()(span: TextSpan) extends RubyNode(span) @@ -119,6 +120,20 @@ object RubyIntermediateAst { span: TextSpan ) extends RubyNode(span) + final case class CaseExpression( + expression: Option[RubyNode], + whenClauses: List[RubyNode], + elseClause: Option[RubyNode] + )(span: TextSpan) + extends RubyNode(span) + + final case class WhenClause( + matchExpressions: List[RubyNode], + matchSplatExpression: Option[RubyNode], + thenClause: RubyNode + )(span: TextSpan) + extends RubyNode(span) + final case class ReturnExpression(expressions: List[RubyNode])(span: TextSpan) extends RubyNode(span) /** Represents an unqualified identifier e.g. `X`, `x`, `@x`, `@@x`, `$x`, `$<`, etc. 
*/ diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala index 44a2b1d3004f..fc1447a53ef4 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala @@ -7,10 +7,8 @@ import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType import org.antlr.v4.runtime.tree.{ErrorNode, ParseTree, TerminalNode} import scala.jdk.CollectionConverters.* -import io.joern.rubysrc2cpg.parser.RubyParser.RescueClauseContext -import io.joern.rubysrc2cpg.parser.RubyParser.EnsureClauseContext -import io.joern.rubysrc2cpg.parser.RubyParser.ExceptionClassListContext import org.antlr.v4.runtime.tree.RuleNode +import io.joern.rubysrc2cpg.parser.RubyParser.SplattingArgumentContext /** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST. 
*/ @@ -555,22 +553,47 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] { } } - override def visitExceptionClassList(ctx: ExceptionClassListContext): RubyNode = { + override def visitExceptionClassList(ctx: RubyParser.ExceptionClassListContext): RubyNode = { // Requires implementing multiple rhs with splatting Unknown()(ctx.toTextSpan) } - override def visitRescueClause(ctx: RescueClauseContext): RubyNode = { + override def visitRescueClause(ctx: RubyParser.RescueClauseContext): RubyNode = { val exceptionClassList = Option(ctx.exceptionClassList).map(visit) val elseClause = Option(ctx.exceptionVariableAssignment).map(visit) val thenClause = visit(ctx.thenClause) RescueClause(exceptionClassList, elseClause, thenClause)(ctx.toTextSpan) } - override def visitEnsureClause(ctx: EnsureClauseContext): RubyNode = { + override def visitEnsureClause(ctx: RubyParser.EnsureClauseContext): RubyNode = { EnsureClause(visit(ctx.compoundStatement()))(ctx.toTextSpan) } + override def visitCaseWithExpression(ctx: RubyParser.CaseWithExpressionContext): RubyNode = { + val expression = Option(ctx.commandOrPrimaryValue()).map(visit) + val whenClauses = Option(ctx.whenClause().asScala).fold(List())(_.map(visit).toList) + val elseClause = Option(ctx.elseClause()).map(visit) + CaseExpression(expression, whenClauses, elseClause)(ctx.toTextSpan) + } + + override def visitCaseWithoutExpression(ctx: RubyParser.CaseWithoutExpressionContext): RubyNode = { + val expression = None + val whenClauses = Option(ctx.whenClause().asScala).fold(List())(_.map(visit).toList) + val elseClause = Option(ctx.elseClause()).map(visit) + CaseExpression(expression, whenClauses, elseClause)(ctx.toTextSpan) + } + + override def visitWhenClause(ctx: RubyParser.WhenClauseContext): RubyNode = { + val whenArgs = ctx.whenArgument() + val matchArgs = + Option(whenArgs.operatorExpressionList()).iterator.flatMap(_.operatorExpression().asScala).map(visit).toList + val matchSplatArg = 
Option(whenArgs.splattingArgument()).map(visit) + val thenClause = visit(ctx.thenClause()) + WhenClause(matchArgs, matchSplatArg, thenClause)(ctx.toTextSpan) + } + + override def visitSplattingArgument(ctx: SplattingArgumentContext): RubyNode = Unknown()(ctx.toTextSpan) + override def visitAssociationKey(ctx: RubyParser.AssociationKeyContext): RubyNode = { if (Option(ctx.operatorExpression()).isDefined) { visit(ctx.operatorExpression()) diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala new file mode 100644 index 000000000000..ed00a238a5ab --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala @@ -0,0 +1,92 @@ +package io.joern.rubysrc2cpg.querying + +import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.Operators + +class CaseTests extends RubyCode2CpgFixture { + "`case x ... 
end` should be represented with if-else chain and multiple match expressions should be or-ed together" in { + val caseCode = """ + |case 0 + | when 0 + | 0 + | when 1,2 then 1 + | when 3, *[4,5] then 2 + | when *[6] then 3 + | else 4 + |end + |""".stripMargin + val cpg = code(caseCode) + + val block @ List(_) = cpg.method(":program").block.astChildren.isBlock.l + + val List(assign) = block.astChildren.assignment.l; + val List(lhs, rhs) = assign.argument.l + + List(lhs).isCall.name.l shouldBe List("") + List(rhs).isLiteral.code.l shouldBe List("0") + + val headIf @ List(_) = block.astChildren.isControlStructure.l + val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; + val conds: List[List[String]] = ifStmts.condition.map { cond => + val orConds = List(cond) + .repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)( + _.emit(_.whereNot(_.isCall.name(Operators.logicalOr))) + ) + .l + orConds.map { + case u: Unknown => "unknown" + case mExpr => + val call @ List(_) = List(mExpr).isCall.l + call.methodFullName.l shouldBe List("===") + val List(lhs, rhs) = call.argument.l + rhs.code shouldBe "" + val List(code) = List(lhs).isCall.argument(1).code.l + code + }.l + }.l + + conds shouldBe List(List("0"), List("1", "2"), List("3", "unknown"), List("unknown")) + val matchResults = ifStmts.astChildren.order(2).astChildren ++ ifStmts.last.astChildren.order(3).astChildren + matchResults.code.l shouldBe List("0", "1", "2", "3", "4") + + // It's not ideal, but we choose the smallest containing text span that we have easily accessible + // as we don't have a good way to immutably update RubyNode text spans. + ifStmts.code.l should contain only caseCode.trim + ifStmts.condition.map(_.code.trim).l shouldBe List("0", "when 1,2 then 1", "when 3, *[4,5] then 2", "*[6]") + } + + "`case ... 
end` without expression" in { + val cpg = code(""" + |case + | when false, true then 0 + | when true then 1 + | when false, *[false,false] then 2 + | when *[false, true] then 3 + |end + |""".stripMargin) + + val block @ List(_) = cpg.method(":program").block.astChildren.isBlock.l + + val headIf @ List(_) = block.astChildren.isControlStructure.l + val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; + val conds: List[List[String]] = ifStmts.condition.map { cond => + val orConds = List(cond) + .repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)( + _.emit(_.whereNot(_.isCall.name(Operators.logicalOr))) + ) + .l + orConds.map { + case u: Unknown => "unknown" + case c => c.code + } + }.l + conds shouldBe List(List("false", "true"), List("true"), List("false", "unknown"), List("unknown")) + + val matchResults = ifStmts.astChildren.order(2).astChildren.l + matchResults.code.l shouldBe List("0", "1", "2", "3") + + ifStmts.last.astChildren.order(3).l shouldBe List() + } +}