Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[c2cpg] Improvements for range-based for-statement and local code fields #5215

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,16 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
"class"
)

private val KeepTypeKeywords: List[String] = List("unsigned", "volatile")
private val KeepTypeKeywords: List[String] = List("unsigned", "volatile", "const", "static")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename this, since `static` is decidedly not part of the type. (I think it's called a storage-class specifier.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it's a list of keywords to keep in the string representation of a type. I'll keep it that way for now.


protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = {
if (rawType == Defines.Any) return rawType
val tpe =
if (stripKeywords) {
ReservedTypeKeywords.foldLeft(rawType) { (cur, repl) =>
if (cur.contains(s"$repl ")) cur.replace(s"$repl ", "") else cur
if (cur.startsWith(s"$repl ") || cur.contains(s" $repl ")) {
cur.replace(s" $repl ", " ").replace(s"$repl ", "")
} else cur
}
} else {
rawType
Expand All @@ -168,17 +170,23 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
case t if t.contains(Defines.QualifiedNameSeparator) => replaceWhitespaceAfterTypeKeyword(fixQualifiedName(t))
case t if KeepTypeKeywords.exists(k => t.startsWith(s"$k ")) => replaceWhitespaceAfterTypeKeyword(t)
case t if t.contains("[") && t.contains("]") => replaceWhitespaceAfterTypeKeyword(t)
case t if t.contains("<") && t.contains(">") => replaceWhitespaceAfterTypeKeyword(t)
case t if t.contains("*") => replaceWhitespaceAfterTypeKeyword(t)
case someType => someType
}
}

private def replaceWhitespaceAfterTypeKeyword(tpe: String): String = {
if (KeepTypeKeywords.exists(k => tpe.startsWith(s"$k "))) {
if (KeepTypeKeywords.exists(k => tpe.startsWith(s"$k ") || tpe.contains(s" $k "))) {
KeepTypeKeywords.foldLeft(tpe) { (cur, repl) =>
val prefix = s"$repl "
if (cur.startsWith(prefix)) {
prefix + cur.substring(prefix.length).replace(" ", "")
val prefixStartsWith = s"$repl "
val prefixContains = s" $repl "
if (cur.startsWith(prefixStartsWith)) {
prefixStartsWith + replaceWhitespaceAfterTypeKeyword(cur.substring(prefixStartsWith.length))
} else if (cur.contains(prefixContains)) {
val front = tpe.substring(0, tpe.indexOf(prefixContains))
val back = tpe.substring(tpe.indexOf(prefixContains) + prefixContains.length)
s"${replaceWhitespaceAfterTypeKeyword(front)}$prefixContains${replaceWhitespaceAfterTypeKeyword(back)}"
} else {
cur
}
Expand Down Expand Up @@ -324,7 +332,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
}

private def typeForCPPAstNamedTypeSpecifier(s: ICPPASTNamedTypeSpecifier, stripKeywords: Boolean): String = {
val tpe = safeGetBinding(s).map(_.toString.replace(" ", "")).getOrElse(ASTStringUtil.getReturnTypeString(s, null))
val tpe = safeGetBinding(s).map(_.toString).getOrElse(s.getRawSignature)
cleanType(tpe, stripKeywords)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,13 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
val code = s"for ($codeInit$codeCond;$codeIter)"
val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code)

val initAstBlock = blockNode(forStmt, Defines.Empty, registerType(Defines.Void)).order(1)
scope.pushNewScope(initAstBlock)
val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement).toList)
val compareAst = astForConditionExpression(forStmt.getConditionExpression, Option(2))
val updateAst = nullSafeAst(forStmt.getIterationExpression, 3)
val bodyAsts = nullSafeAst(forStmt.getBody, 4)
scope.popScope()
forAst(forNode, Seq.empty, Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts)
val (localAsts, initAsts) =
nullSafeAst(forStmt.getInitializerStatement).partition(_.root.exists(_.isInstanceOf[NewLocal]))
setArgumentIndices(initAsts)
val compareAst = astForConditionExpression(forStmt.getConditionExpression)
val updateAst = nullSafeAst(forStmt.getIterationExpression)
val bodyAsts = nullSafeAst(forStmt.getBody)
forAst(forNode, localAsts, initAsts, Seq(compareAst), Seq(updateAst), bodyAsts)
}

private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = {
Expand All @@ -325,14 +324,18 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code)
forStmt.getDeclaration match {
case declaration: ICPPASTStructuredBindingDeclaration =>
val initAsts = astsForStructuredBindingDeclaration(declaration, Some(forStmt.getInitializerClause))
val bodyAsts = nullSafeAst(forStmt.getBody, 4)
controlStructureAst(forNode, None, (initAsts ++ bodyAsts).toList)
val (localAsts, initAsts) = astsForStructuredBindingDeclaration(declaration, Some(forStmt.getInitializerClause))
.partition(_.root.exists(_.isInstanceOf[NewLocal]))
setArgumentIndices(initAsts)
val bodyAst = nullSafeAst(forStmt.getBody)
forAst(forNode, localAsts, initAsts.filterNot(_.nodes.isEmpty), Seq.empty, Seq.empty, bodyAst)
case _ =>
val initAst = astForNode(forStmt.getInitializerClause)
val declAst = astsForDeclaration(forStmt.getDeclaration)
val stmtAst = nullSafeAst(forStmt.getBody)
controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst)
val init = astForNode(forStmt.getInitializerClause)
val declAsts = astsForDeclaration(forStmt.getDeclaration)
setArgumentIndices(init +: declAsts)
val (localAsts, initAsts) = (init +: declAsts).partition(_.root.exists(_.isInstanceOf[NewLocal]))
val bodyAst = nullSafeAst(forStmt.getBody)
forAst(forNode, localAsts, initAsts.filterNot(_.nodes.isEmpty), Seq.empty, Seq.empty, bodyAst)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,15 +700,16 @@ class AstCreationPassTests extends AstC2CpgSuite {
)
inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) =>
forStmt.controlStructureType shouldBe ControlStructureTypes.FOR
inside(forStmt.astChildren.order(1).l) { case List(ident: Identifier) =>
ident.code shouldBe "list"
}
inside(forStmt.astChildren.order(2).l) { case List(x: Local) =>
inside(forStmt.astChildren.isLocal.l) { case List(x: Local) =>
x.name shouldBe "x"
x.typeFullName shouldBe "int"
x.code shouldBe "int x"
}
inside(forStmt.astChildren.order(3).l) { case List(block: Block) =>
// for the expected orders see CfgCreator.cfgForForStatement
inside(forStmt.astChildren.order(2).l) { case List(ident: Identifier) =>
ident.code shouldBe "list"
}
inside(forStmt.astChildren.order(5).l) { case List(block: Block) =>
block.astChildren.isCall.code.l shouldBe List("z = x")
}
}
Expand All @@ -726,7 +727,7 @@ class AstCreationPassTests extends AstC2CpgSuite {
)
inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) =>
forStmt.controlStructureType shouldBe ControlStructureTypes.FOR
forStmt.astChildren.isCall.code.l shouldBe List(
forStmt.astChildren.isBlock.astChildren.isCall.code.l shouldBe List(
"anonymous_tmp_0 = foo",
"a = anonymous_tmp_0[0]",
"b = anonymous_tmp_0[1]"
Expand Down Expand Up @@ -819,7 +820,7 @@ class AstCreationPassTests extends AstC2CpgSuite {
""".stripMargin)
val List(forLoop) = cpg.controlStructure.l
val List(conditionBlock) = forLoop.condition.collectAll[Block].l
conditionBlock.argumentIndex shouldBe 2
conditionBlock.order shouldBe 2
val List(assignmentCall, greaterCall) = conditionBlock.astChildren.collectAll[Call].l
assignmentCall.argumentIndex shouldBe 1
assignmentCall.code shouldBe "b = something()"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,26 @@ class ControlStructureTests extends C2CpgSuite(FileDefaults.CppExt) {
"should be correct for for-loop with multiple assignments" in {
inside(cpg.controlStructure.l) { case List(forLoop) =>
forLoop.controlStructureType shouldBe ControlStructureTypes.FOR
inside(forLoop.astChildren.order(1).l) { case List(assignmentBlock) =>
inside(assignmentBlock.astChildren.l) { case List(localX, localY, assignmentX, assignmentY) =>
localX.code shouldBe "int x"
localX.order shouldBe 1
localY.code shouldBe "int y"
localY.order shouldBe 2
inside(forLoop.astChildren.isLocal.l) { case List(localX, localY) =>
localX.code shouldBe "int x"
localY.code shouldBe "int y"
}
inside(forLoop.astChildren.order(3).l) { case List(assignmentBlock) =>
inside(assignmentBlock.astChildren.l) { case List(assignmentX, assignmentY) =>
assignmentX.code shouldBe "x=1"
assignmentX.order shouldBe 3
assignmentX.order shouldBe 1
assignmentY.code shouldBe "y=1"
assignmentY.order shouldBe 4
assignmentY.order shouldBe 2
}
}
inside(forLoop.condition.l) { case List(x) =>
x.code shouldBe "x"
x.order shouldBe 2
x.order shouldBe 4
}
inside(forLoop.astChildren.order(3).l) { case List(updateX) =>
inside(forLoop.astChildren.order(5).l) { case List(updateX) =>
updateX.code shouldBe "--x"
}
inside(forLoop.astChildren.order(4).l) { case List(loopBody) =>
inside(forLoop.astChildren.order(6).l) { case List(loopBody) =>
loopBody.astChildren.isCall.head.code shouldBe "bar()"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TemplateTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CppExt) {
typeDeclA.aliasTypeFullName shouldBe Option("X<int>")
typeDeclB.name shouldBe "B"
typeDeclB.fullName shouldBe "B"
typeDeclB.aliasTypeFullName shouldBe Option("Y<int, char>")
typeDeclB.aliasTypeFullName shouldBe Option("Y<int,char>")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,79 +7,97 @@ import io.shiftleft.semanticcpg.language.*
*/
class LocalQueryTests extends C2CpgSuite {

private val cpg = code("""
| struct node {
| int value;
| struct node *next;
| };
|
| void free_list(struct node *head) {
| struct node *q;
| for (struct node *p = head; p != NULL; p = q) {
| q = p->next;
| free(p);
| }
| }
|
| int flow(int p0) {
| int a = p0;
| int b = a;
| int c = 0x31;
| int z = b + c;
| z++;
| int x = z;
| return x;
| }
|
| void test() {
| static int a, b, c;
| wchar_t *foo;
| int d[10], e = 1;
| }
| """.stripMargin)

"should allow to query for all locals" in {
cpg.local.name.toSetMutable shouldBe Set("a", "b", "c", "e", "d", "z", "x", "q", "p", "foo")
"local query example 1" should {
"allow to query for the local" in {
val cpg = code(
"""
|void foo() {
| static const Foo::Bar bar{};
|}
|""".stripMargin,
"test.cpp"
)
val List(barLocal) = cpg.method.name("foo").local.l
barLocal.name shouldBe "bar"
barLocal.typeFullName shouldBe "Foo.Bar"
barLocal.code shouldBe "static const Foo.Bar bar"
}
}

"should prove correct (name, type) pairs for locals" in {
inside(cpg.method.name("free_list").local.l) { case List(q, p) =>
q.name shouldBe "q"
q.typeFullName shouldBe "node*"
q.code shouldBe "struct node* q"
p.name shouldBe "p"
p.typeFullName shouldBe "node*"
p.code shouldBe "struct node* p"
"local query example 2" should {
val cpg = code("""
| struct node {
| int value;
| struct node *next;
| };
|
| void free_list(struct node *head) {
| struct node *q;
| for (struct node *p = head; p != NULL; p = q) {
| q = p->next;
| free(p);
| }
| }
|
| int flow(int p0) {
| int a = p0;
| int b = a;
| int c = 0x31;
| int z = b + c;
| z++;
| int x = z;
| return x;
| }
|
| void test() {
| static int a, b, c;
| wchar_t *foo;
| int d[10], e = 1;
| }
| """.stripMargin)

"should allow to query for all locals" in {
cpg.local.name.toSetMutable shouldBe Set("a", "b", "c", "e", "d", "z", "x", "q", "p", "foo")
}
}

"should prove correct (name, type, code) pairs for locals" in {
inside(cpg.method.name("test").local.l) { case List(a, b, c, foo, d, e) =>
a.name shouldBe "a"
a.typeFullName shouldBe "int"
a.code shouldBe "static int a"
b.name shouldBe "b"
b.typeFullName shouldBe "int"
b.code shouldBe "static int b"
c.name shouldBe "c"
c.typeFullName shouldBe "int"
c.code shouldBe "static int c"
foo.name shouldBe "foo"
foo.typeFullName shouldBe "wchar_t*"
foo.code shouldBe "wchar_t* foo"
d.name shouldBe "d"
d.typeFullName shouldBe "int[10]"
d.code shouldBe "int[10] d"
e.name shouldBe "e"
e.typeFullName shouldBe "int"
e.code shouldBe "int e"
"should prove correct (name, type) pairs for locals" in {
inside(cpg.method.name("free_list").local.l) { case List(q, p) =>
q.name shouldBe "q"
q.typeFullName shouldBe "node*"
q.code shouldBe "struct node* q"
p.name shouldBe "p"
p.typeFullName shouldBe "node*"
p.code shouldBe "struct node* p"
}
}
}

"should allow finding filenames by local regex" in {
val filename = cpg.local.name("a*").file.name.headOption
filename should not be empty
filename.head.endsWith(".c") shouldBe true
}
"should prove correct (name, type, code) pairs for locals" in {
inside(cpg.method.name("test").local.l) { case List(a, b, c, foo, d, e) =>
a.name shouldBe "a"
a.typeFullName shouldBe "int"
a.code shouldBe "static int a"
b.name shouldBe "b"
b.typeFullName shouldBe "int"
b.code shouldBe "static int b"
c.name shouldBe "c"
c.typeFullName shouldBe "int"
c.code shouldBe "static int c"
foo.name shouldBe "foo"
foo.typeFullName shouldBe "wchar_t*"
foo.code shouldBe "wchar_t* foo"
d.name shouldBe "d"
d.typeFullName shouldBe "int[10]"
d.code shouldBe "int[10] d"
e.name shouldBe "e"
e.typeFullName shouldBe "int"
e.code shouldBe "int e"
}
}

"should allow finding filenames by local regex" in {
val filename = cpg.local.name("a*").file.name.headOption
filename should not be empty
filename.head.endsWith(".c") shouldBe true
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V
): Ast =
forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst))

/** Assigns `order` to the root node of `ast` when that root is an [[ExpressionNew]].
  *
  * ASTs without a root, or whose root is any other kind of node, are left untouched. The result of the `order(...)`
  * call is discarded, so this relies on the call updating the node in place.
  *
  * @param ast
  *   the AST whose root node should receive the order
  * @param order
  *   the order value to set on the root node
  * @return
  *   the very same `ast` instance that was passed in
  */
private def setOrderExplicitly(ast: Ast, order: Int): Ast = {
  ast.root.foreach {
    case expression: ExpressionNew => expression.order(order)
    case _                         => // not an expression root: nothing to update
  }
  ast
}

def forAst(
forNode: NewControlStructure,
locals: Seq[Ast],
Expand All @@ -206,12 +213,15 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V
updateAsts: Seq[Ast],
bodyAsts: Seq[Ast]
): Ast = {
val lineNumber = forNode.lineNumber
val lineNumber = forNode.lineNumber
val numOfLocals = locals.size
// for the expected orders see CfgCreator.cfgForForStatement
if (bodyAsts.nonEmpty) setOrderExplicitly(bodyAsts.head, numOfLocals + 4)
Ast(forNode)
.withChildren(locals)
.withChild(wrapMultipleInBlock(initAsts, lineNumber))
.withChild(wrapMultipleInBlock(conditionAsts, lineNumber))
.withChild(wrapMultipleInBlock(updateAsts, lineNumber))
.withChild(setOrderExplicitly(wrapMultipleInBlock(initAsts, lineNumber), numOfLocals + 1))
.withChild(setOrderExplicitly(wrapMultipleInBlock(conditionAsts, lineNumber), numOfLocals + 2))
.withChild(setOrderExplicitly(wrapMultipleInBlock(updateAsts, lineNumber), numOfLocals + 3))
.withChildren(bodyAsts)
.withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList)
}
Expand Down
Loading