Skip to content

Commit

Permalink
[C#] Ast for switch and label statements (#4095)
Browse files Browse the repository at this point in the history
This PR includes, 
1. AST Creation for `Switch` statements.
2. Corresponding tests. 
3. Upgrading the `DotNetAstGen` version to allow new properties.

Resolves #3984
  • Loading branch information
karan-batavia authored Jan 29, 2024
1 parent f481777 commit 06ecfb0
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
csharpsrc2cpg {
dotnetastgen_version: "0.16.0"
dotnetastgen_version: "0.18.0"
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ import io.circe.Json
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.{
BinaryExpr,
Block,
CasePatternSwitchLabel,
CaseSwitchLabel,
DefaultSwitchLabel,
DoStatement,
ExpressionStatement,
ForEachStatement,
ForStatement,
GlobalStatement,
IfStatement,
LiteralExpr,
SwitchStatement,
ThrowStatement,
TryStatement,
UnaryExpr,
Expand All @@ -21,6 +25,8 @@ import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.ControlStructureTypes
import io.shiftleft.codepropertygraph.generated.nodes.ControlStructure

import scala.::
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}

trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
Expand Down Expand Up @@ -55,9 +61,45 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case ForStatement => astForForStatement(nodeInfo)
case DoStatement => astForDoStatement(nodeInfo)
case WhileStatement => astForWhileStatement(nodeInfo)
case SwitchStatement => astForSwitchStatement(nodeInfo)
case _ => notHandledYet(nodeInfo)
}

private def astForSwitchLabel(labelNode: DotNetNodeInfo): Seq[Ast] = {
val caseNode = jumpTargetNode(labelNode, "case", labelNode.code)
labelNode.node match
case CasePatternSwitchLabel =>
val patternNode = createDotNetNodeInfo(labelNode.json(ParserKeys.Pattern)(ParserKeys.Expression))
Ast(caseNode) +: astForNode(patternNode)
case CaseSwitchLabel =>
val valueNode = createDotNetNodeInfo(labelNode.json(ParserKeys.Value))
Ast(caseNode) +: astForNode(valueNode)
case DefaultSwitchLabel => Seq(Ast(caseNode))
case _ => Seq(Ast())
}

private def astForSwitchStatement(switchStmt: DotNetNodeInfo): Seq[Ast] = {
val comparatorNode = createDotNetNodeInfo(switchStmt.json(ParserKeys.Expression))
val comparatorNodeAst = astForExpression(comparatorNode).headOption

val switchBodyAsts: Seq[Ast] = switchStmt
.json(ParserKeys.Sections)
.arr
.flatMap(section =>
val sectionNode = section match
case value: ujson.Obj => createDotNetNodeInfo(value)
case value: ujson.Value => nullSafeCreateParserNodeInfo(Option(value))

val labelNodes = sectionNode.json(ParserKeys.Labels).arr
labelNodes.flatMap(labelNode => astForSwitchLabel(createDotNetNodeInfo(labelNode))) :+ astForBlock(sectionNode)
)
.toSeq

val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, s"switch (${comparatorNode.code})");

Seq(controlStructureAst(switchNode, comparatorNodeAst, switchBodyAsts))
}

private def astForWhileStatement(whileStmt: DotNetNodeInfo): Seq[Ast] = {
val whileBlock = createDotNetNodeInfo(whileStmt.json(ParserKeys.Statement))
val whileBlockAst = astForBlock(whileBlock)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ object DotNetJsonAst {
sealed trait BaseExpr extends DotNetParserNode
sealed trait BaseStmt extends DotNetParserNode

sealed trait BasePattern extends DotNetParserNode

sealed trait BaseLabel extends DotNetParserNode

object GlobalStatement extends BaseStmt
object ExpressionStatement extends BaseStmt

Expand Down Expand Up @@ -156,6 +160,20 @@ object DotNetJsonAst {

object WhileStatement extends BaseStmt

object SwitchStatement extends BaseStmt

object SwitchSection extends BaseExpr

object RelationalPattern extends BasePattern

object ConstantPattern extends BasePattern

object CaseSwitchLabel extends BaseLabel

object CasePatternSwitchLabel extends BaseLabel

object DefaultSwitchLabel extends BaseLabel

object Unknown extends DotNetParserNode

}
Expand Down Expand Up @@ -185,6 +203,7 @@ object ParserKeys {
val Initializer = "Initializer"
val Keyword = "Keyword"
val Kind = "Kind"
val Labels = "Labels"
val Left = "Left"
val LineStart = "LineStart"
val LineEnd = "LineEnd"
Expand All @@ -196,6 +215,8 @@ object ParserKeys {
val OperatorToken = "OperatorToken"
val Parameters = "Parameters"
val ParameterList = "ParameterList"
val Pattern = "Pattern"
val Sections = "Sections"
val Statement = "Statement"
val Statements = "Statements"
val ReturnType = "ReturnType"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.joern.csharpsrc2cpg.querying.ast
import io.joern.csharpsrc2cpg.CSharpOperators
import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Local}
import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, JumpTarget, Local}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes}
import io.shiftleft.semanticcpg.language.*

Expand Down Expand Up @@ -106,4 +106,81 @@ class ControlStructureTests extends CSharpCode2CpgFixture {

}

"the swtich statement" should {
val cpg = code(basicBoilerplate("""
|switch (i) {
| case > 0:
| i++;
| break;
| case < 0:
| i--;
| break;
| default:
| i += 10;
| break;
|}
|""".stripMargin))

"create a control structure node and contain correct astChildren" in {
inside(cpg.method("Main").controlStructure.controlStructureTypeExact(ControlStructureTypes.SWITCH).l) {
case switchNode :: Nil =>
switchNode.code shouldBe "switch (i)";
switchNode.controlStructureType shouldBe ControlStructureTypes.SWITCH

inside(switchNode.astChildren.isBlock.l) { case case1 :: case2 :: case3 :: Nil =>
val List(incCall) = case1.astChildren.isCall.l;
incCall.code shouldBe "i++"

val List(decCall) = case2.astChildren.isCall.l;
decCall.code shouldBe "i--"

val List(plusEqualsCall) = case3.astChildren.isCall.l;
plusEqualsCall.code shouldBe "i += 10"
}

inside(switchNode.astChildren.collect { case j: JumpTarget => j }.l) {
case case1 :: case2 :: defaultCase :: Nil =>
case1.code shouldBe "case > 0:"
case2.code shouldBe "case < 0:"
defaultCase.code shouldBe "default:"
}
}
}
}

"switch statement with multiple labels" should {
val cpg = code(basicBoilerplate("""
|switch (i) {
| case > 0:
| case < 10:
| i++;
| break;
| case 10:
| i--;
| break;
| default:
| i += 10;
| break;
|}
|""".stripMargin))

"create a control structure node with correct label and case clauses" in {

inside(cpg.method("Main").controlStructure.controlStructureTypeExact(ControlStructureTypes.SWITCH).l) {
case switchNode :: Nil =>
switchNode.code shouldBe "switch (i)";
switchNode.controlStructureType shouldBe ControlStructureTypes.SWITCH

inside(switchNode.astChildren.collect { case j: JumpTarget => j }.l) {
case case1 :: case1_1 :: case2 :: defaultCase :: Nil =>
case1.code shouldBe "case > 0:"
case2.code shouldBe "case 10:"
case1_1.code shouldBe "case < 10:"
defaultCase.code shouldBe "default:"
}
}
}

}

}

0 comments on commit 06ecfb0

Please sign in to comment.