Skip to content

Commit

Permalink
[ruby] Constructor Lowering Fix (#4845)
Browse files Browse the repository at this point in the history
* [ruby] Constructor Lowering Fix
* Fixed constructor lowering structure as per #4822
* Tested and fixed issue in parenthesis-less `new` member calls (could use a better long-term fix in the parser)

* Testing and handling an expression base of constructor

* Fixed some imports

* Using `code(node)`
  • Loading branch information
DavidBakerEffendi authored Aug 13, 2024
1 parent 00c911f commit 10a089e
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,17 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

protected def astForObjectInstantiation(node: RubyNode & ObjectInstantiation): Ast = {
val className = node.target.text
val callName = "new"
val methodName = Defines.Initialize
/*
We short-cut the call edge from `new` call to `initialize` method, however we keep the modelling of the receiver
as referring to the singleton class.
*/
val (receiverTypeFullName, fullName) = scope.tryResolveTypeReference(className) match {
case Some(typeMetaData) => s"${typeMetaData.name}<class>" -> s"${typeMetaData.name}.$methodName"
case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName
val (receiverTypeFullName, fullName) = node.target match {
case x: (SimpleIdentifier | MemberAccess) =>
scope.tryResolveTypeReference(x.text) match {
case Some(typeMetaData) => s"${typeMetaData.name}<class>" -> s"${typeMetaData.name}.${Defines.Initialize}"
case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName
}
case _ => XDefines.Any -> XDefines.DynamicCallUnknownFullName
}
/*
Similarly to some other frontends, we lower the constructor into two operations, e.g.,
Expand All @@ -287,7 +288,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

val tmpName = tmpGen.fresh
val tmpTypeHint = receiverTypeFullName.stripSuffix("<class>")
val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpName))
val tmp = SimpleIdentifier(None)(node.span.spanStart(tmpName))
val tmpLocal = NewLocal().name(tmpName).code(tmpName).dynamicTypeHintFullName(Seq(tmpTypeHint))
scope.addToScope(tmpName, tmpLocal)

Expand All @@ -298,12 +299,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

// Assign tmp to <alloc>
val receiverAst = Ast(identifierNode(node, className, className, receiverTypeFullName))
val allocCall = callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH)
val allocAst = callAst(allocCall, Seq.empty, Option(receiverAst))
val allocCall = callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH)
val allocAst = callAst(allocCall, Seq.empty)
val assignmentCall = callNode(
node,
s"${tmp.text} = ${code(node)}",
s"${tmp.text} = ${code(node.target)}.${Defines.Initialize}",
Operators.assignment,
Operators.assignment,
DispatchTypes.STATIC_DISPATCH
Expand All @@ -318,8 +318,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
x.arguments.map(astForMethodCallArgument) :+ typeRef
}

val constructorCall = callNode(node, code(node), callName, fullName, DispatchTypes.DYNAMIC_DISPATCH)
val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier))
val constructorCall =
callNode(node, code(node), Defines.Initialize, Defines.Any, DispatchTypes.DYNAMIC_DISPATCH)
if fullName != XDefines.DynamicCallUnknownFullName then constructorCall.dynamicTypeHintFullName(Seq(fullName))
val constructorRecv = astForExpression(MemberAccess(node.target, ".", Defines.Initialize)(node.span))
val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier), Option(constructorRecv))
val retIdentifierAst = tmpIdentifier
scope.popScope()

Expand Down Expand Up @@ -864,8 +867,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = {
val (memberName, memberCode) = node.target match {
case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@")
case _: TypeIdentifier => node.memberName -> node.memberName
case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize
case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@")
case _: TypeIdentifier => node.memberName -> node.memberName
case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) =>
s"@${node.memberName}" -> node.memberName
case _ => node.memberName -> node.memberName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,10 +692,22 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitMemberAccessCommand(ctx: RubyParser.MemberAccessCommandContext): RubyNode = {
val args = ctx.commandArgument.arguments.map(visit)
val methodName = visit(ctx.methodName())
val base = visit(ctx.primary())
MemberCall(base, ".", methodName.text, args)(ctx.toTextSpan)
val args = ctx.commandArgument.arguments.map(visit)
val base = visit(ctx.primary())

if (ctx.methodName().getText == "new") {
base match {
case SingleAssignment(lhs, op, rhs) =>
// fixme: Parser packaging arguments from a parenthesis-less object instantiation is odd
val assignSpan = base.span.spanStart(s"${base.span.text}.new")
val rhsSpan = rhs.span.spanStart(s"${rhs.span.text}.new")
SingleAssignment(lhs, op, SimpleObjectInstantiation(rhs, args)(rhsSpan))(assignSpan)
case _ => SimpleObjectInstantiation(base, args)(ctx.toTextSpan)
}
} else {
val methodName = visit(ctx.methodName())
MemberCall(base, ".", methodName.text, args)(ctx.toTextSpan)
}
}

override def visitConstantIdentifierVariable(ctx: RubyParser.ConstantIdentifierVariableContext): RubyNode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Call, Literal}
import io.shiftleft.semanticcpg.language.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.Defines as XDefines
import io.shiftleft.codepropertygraph.generated.nodes.Literal

class ArrayTests extends RubyCode2CpgFixture {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes, Operators}
import io.shiftleft.semanticcpg.language.*

class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) {
Expand Down Expand Up @@ -155,13 +155,15 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) {
"a simple object instantiation" should {

val cpg = code("""class A
| def initialize(a, b)
| end
|end
|
|a = A.new
|a = A.new 1, 2
|""".stripMargin)

"create an assignment from `a` to an <init> invocation block" in {
inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("a")).l) {
"create an assignment from `a` to an alloc lowering invocation block" in {
inside(cpg.method.isModule.assignment.and(_.target.isIdentifier.name("a"), _.source.isBlock).l) {
case assignment :: Nil =>
assignment.code shouldBe "a = A.new"
inside(assignment.argument.l) {
Expand All @@ -174,7 +176,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) {
}
}

"create an assignment from a temp variable to the <init> call" in {
"create an assignment from a temp variable to the alloc call" in {
inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("<tmp-0>")).l) {
case assignment :: Nil =>
inside(assignment.argument.l) {
Expand All @@ -184,22 +186,76 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) {
alloc.name shouldBe Operators.alloc
alloc.methodFullName shouldBe Operators.alloc
alloc.code shouldBe "A.new"
alloc.argument.size shouldBe 0
case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected a single assignment, got [${xs.code.mkString(",")}]")
}
}

"create a call to the object's constructor, with the temp variable receiver" in {
inside(cpg.call.nameExact("new").l) {
inside(cpg.call.nameExact(RubyDefines.Initialize).l) {
case constructor :: Nil =>
inside(constructor.argument.l) {
case (a: Identifier) :: Nil =>
case (a: Identifier) :: (one: Literal) :: (two: Literal) :: Nil =>
a.name shouldBe "<tmp-0>"
a.typeFullName shouldBe s"Test0.rb:$Main.A"
a.argumentIndex shouldBe 0

one.code shouldBe "1"
two.code shouldBe "2"
case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]")
}

val recv = constructor.receiver.head.asInstanceOf[Call]
recv.methodFullName shouldBe Operators.fieldAccess
recv.name shouldBe Operators.fieldAccess
recv.code shouldBe s"A.${RubyDefines.Initialize}"

recv.argument(1).label shouldBe NodeTypes.CALL
recv.argument(1).code shouldBe "self.A"
recv.argument(2).label shouldBe NodeTypes.FIELD_IDENTIFIER
recv.argument(2).code shouldBe RubyDefines.Initialize
case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]")
}
}
}

"an object instantiation from some expression" should {
val cpg = code("""def foo
| params[:type].constantize.new(path)
|end
|""".stripMargin)

"create a call node on the receiver end of the constructor lowering" in {
inside(cpg.call.nameExact(RubyDefines.Initialize).l) {
case constructor :: Nil =>
inside(constructor.argument.l) {
case (a: Identifier) :: (selfPath: Call) :: Nil =>
a.name shouldBe "<tmp-0>"
a.typeFullName shouldBe Defines.Any
a.argumentIndex shouldBe 0

selfPath.code shouldBe "self.path"
case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]")
}

val recv = constructor.receiver.head.asInstanceOf[Call]
recv.methodFullName shouldBe Operators.fieldAccess
recv.name shouldBe Operators.fieldAccess
recv.code shouldBe s"params[:type].constantize.${RubyDefines.Initialize}"

inside(recv.argument.l) { case (constantize: Call) :: (initialize: FieldIdentifier) :: Nil =>
constantize.code shouldBe "params[:type].constantize"
inside(constantize.argument.l) { case (indexAccess: Call) :: (const: FieldIdentifier) :: Nil =>
indexAccess.name shouldBe Operators.indexAccess
indexAccess.code shouldBe "params[:type]"

const.canonicalName shouldBe "constantize"
}

initialize.canonicalName shouldBe RubyDefines.Initialize
}
case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies =
case (v: Identifier) :: (block: Block) :: Nil =>
v.dynamicTypeHintFullName should contain("dummy_logger.Main_module.Main_outer_class")

inside(block.astChildren.isCall.nameExact("new").headOption) {
inside(block.astChildren.isCall.nameExact(RubyDefines.Initialize).headOption) {
case Some(constructorCall) =>
constructorCall.methodFullName shouldBe s"dummy_logger.Main_module.Main_outer_class.${RubyDefines.Initialize}"
constructorCall.methodFullName shouldBe Defines.Any
case None => fail(s"Expected constructor call, did not find one")
}
case xs => fail(s"Expected two arguments under the constructor assignment, got [${xs.code.mkString(", ")}]")
Expand All @@ -109,9 +109,9 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies =
case (g: Identifier) :: (block: Block) :: Nil =>
g.dynamicTypeHintFullName should contain("dummy_logger.Help")

inside(block.astChildren.isCall.name("new").headOption) {
inside(block.astChildren.isCall.name(RubyDefines.Initialize).headOption) {
case Some(constructorCall) =>
constructorCall.methodFullName shouldBe s"dummy_logger.Help.${RubyDefines.Initialize}"
constructorCall.methodFullName shouldBe Defines.Any
case None => fail(s"Expected constructor call, did not find one")
}
case xs => fail(s"Expected two arguments under the constructor assignment, got [${xs.code.mkString(", ")}]")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.passes.GlobalTypes.builtinPrefix
import io.joern.rubysrc2cpg.passes.Defines.Main
import io.joern.rubysrc2cpg.passes.Defines.{Main, Initialize}
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.nodes.*
Expand Down Expand Up @@ -267,10 +267,11 @@ class DoBlockTests extends RubyCode2CpgFixture {
inside(constrBlock.astChildren.l) {
case (tmpLocal: Local) :: (tmpAssign: Call) :: (newCall: Call) :: (_: Identifier) :: Nil =>
tmpLocal.name shouldBe "<tmp-0>"
tmpAssign.code shouldBe "<tmp-0> = Array.new(x) { |i| i += 1 }"
tmpAssign.code shouldBe s"<tmp-0> = Array.$Initialize"

newCall.name shouldBe "new"
newCall.methodFullName shouldBe s"$builtinPrefix.Array.initialize"
newCall.name shouldBe Initialize
newCall.methodFullName shouldBe Defines.Any
newCall.dynamicTypeHintFullName should contain(s"$builtinPrefix.Array.$Initialize")

inside(newCall.argument.l) {
case (_: Identifier) :: (x: Identifier) :: (closure: TypeRef) :: Nil =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.Main
import io.joern.rubysrc2cpg.passes.Defines.{Initialize, Main}
import io.joern.rubysrc2cpg.passes.GlobalTypes.{builtinPrefix, kernelPrefix}
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.nodes.Literal
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes}
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess
import org.scalatest.Inspectors

class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with Inspectors {
Expand Down Expand Up @@ -62,7 +61,13 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In
)

val List(newCall) =
cpg.method.isModule.filename("t1.rb").ast.isCall.methodFullName(".*\\.initialize").methodFullName.l
cpg.method.isModule
.filename("t1.rb")
.ast
.isCall
.dynamicTypeHintFullName
.filter(x => x.startsWith(path) && x.endsWith(Initialize))
.l
newCall should startWith(s"$path.rb:")
}
}
Expand Down Expand Up @@ -285,12 +290,13 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In

"resolve calls to builtin functions" in {
inside(cpg.call.methodFullName("(pp|csv).*").l) {
case csvParseCall :: csvTableInitCall :: ppCall :: Nil =>
case csvParseCall :: ppCall :: Nil =>
csvParseCall.methodFullName shouldBe "csv.CSV.parse"
ppCall.methodFullName shouldBe "pp.PP.pp"
csvTableInitCall.methodFullName shouldBe "csv.CSV.Table.initialize"
case xs => fail(s"Expected three calls, got [${xs.code.mkString(",")}] instead")
}

cpg.call(Initialize).dynamicTypeHintFullName.toSet should contain("csv.CSV.Table.initialize")
}
}

Expand Down

0 comments on commit 10a089e

Please sign in to comment.