Skip to content

Commit

Permalink
Add proper support for java record classes (#5171)
Browse files Browse the repository at this point in the history
  • Loading branch information
johannescoetzee authored Dec 9, 2024
1 parent 582ede4 commit 5af4cc0
Show file tree
Hide file tree
Showing 11 changed files with 927 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class AstCreator(
TypeConstants.Any
}

private[astcreation] def isResolvedTypeFullName(typeFullName: String): Boolean = {
typeFullName != TypeConstants.Any && !typeFullName.startsWith(Defines.UnresolvedNamespace)
}

/** Custom printer that omits comments. To be used by [[code]] */
private val codePrinterOptions = new DefaultPrinterConfiguration()
.removeOption(new DefaultConfigurationOption(ConfigOption.PRINT_COMMENTS))
Expand Down Expand Up @@ -372,8 +376,10 @@ class AstCreator(
case _ => None
}

def argumentTypesForMethodLike(maybeResolvedMethodLike: Try[ResolvedMethodLikeDeclaration]): Option[List[String]] = {
maybeResolvedMethodLike.toOption
def argumentTypesForMethodLike(
maybeResolvedMethodLike: Option[ResolvedMethodLikeDeclaration]
): Option[List[String]] = {
maybeResolvedMethodLike
.flatMap(calcParameterTypes(_, ResolvedTypeParametersMap.empty()))
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package io.joern.javasrc2cpg.astcreation.declarations

import com.github.javaparser.ast.body.{
AnnotationDeclaration,
AnnotationMemberDeclaration,
BodyDeclaration,
ClassOrInterfaceDeclaration,
CompactConstructorDeclaration,
ConstructorDeclaration,
EnumConstantDeclaration,
EnumDeclaration,
FieldDeclaration,
InitializerDeclaration,
MethodDeclaration,
RecordDeclaration,
TypeDeclaration,
VariableDeclarator
}
Expand Down Expand Up @@ -57,9 +61,6 @@ import scala.jdk.CollectionConverters.*
import scala.util.{Success, Try}
import com.github.javaparser.ast.expr.ObjectCreationExpr
import com.github.javaparser.ast.stmt.LocalClassDeclarationStmt
import com.github.javaparser.ast.body.AnnotationMemberDeclaration
import com.github.javaparser.ast.body.CompactConstructorDeclaration
import com.github.javaparser.ast.body.EnumDeclaration
import io.joern.javasrc2cpg.scope.Scope.ScopeVariable
import com.github.javaparser.ast.Node
import com.github.javaparser.resolution.types.ResolvedReferenceType
Expand Down Expand Up @@ -117,7 +118,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
methodDeclaration.getNameAsString
}.toSet

scope.pushTypeDeclScope(typeDeclRoot, scope.isEnclosingScopeStatic, declaredMethodNames)
scope.pushTypeDeclScope(typeDeclRoot, scope.isEnclosingScopeStatic, declaredMethodNames, Nil)
val memberAsts = astsForTypeDeclMembers(expr, body, isInterface = false, typeFullName)

val localDecls = scope.localDeclsInScope
Expand Down Expand Up @@ -173,7 +174,16 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
createTypeDeclNode(typeDeclaration, astParentType, astParentFullName, isInterface, fullNameOverride)

val declaredMethodNames = typeDeclaration.getMethods.asScala.map(_.getNameAsString).toSet
scope.pushTypeDeclScope(typeDeclRoot, typeDeclaration.isStatic, declaredMethodNames)

val (recordParameters, recordParameterAsts) = typeDeclaration match {
case recordDeclaration: RecordDeclaration =>
val parameters = recordDeclaration.getParameters.asScala.toList
val asts = astsForRecordParameters(recordDeclaration, typeDeclRoot.fullName)
(parameters, asts)
case _ => (Nil, Nil)
}

scope.pushTypeDeclScope(typeDeclRoot, typeDeclaration.isStatic, declaredMethodNames, recordParameters)
addTypeDeclTypeParamsToScope(typeDeclaration)

val annotationAsts = typeDeclaration.getAnnotations.asScala.map(astForAnnotationExpr)
Expand All @@ -182,6 +192,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
case enumDeclaration: EnumDeclaration => enumDeclaration.getEntries.asScala.toList
case _ => Nil
}

val memberAsts =
astsForTypeDeclMembers(
typeDeclaration,
Expand All @@ -194,6 +205,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
val lambdaMethods = scope.lambdaMethodsInScope

val typeDeclAst = Ast(typeDeclRoot)
.withChildren(recordParameterAsts)
.withChildren(memberAsts)
.withChildren(annotationAsts)
.withChildren(localDecls)
Expand Down Expand Up @@ -228,6 +240,32 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
typeDeclAst
}

private def astsForRecordParameters(recordDeclaration: RecordDeclaration, recordTypeFullName: String): List[Ast] = {
val explicitMethodNames = recordDeclaration.getMethods.asScala.map(_.getNameAsString).toSet

recordDeclaration.getParameters.asScala.toList.flatMap { parameter =>
val parameterName = parameter.getNameAsString
val parameterTypeFullName = tryWithSafeStackOverflow {
val typ = parameter.getType
scope
.lookupScopeType(typ.asString())
.map(_.typeFullName)
.orElse(typeInfoCalc.fullName(typ))
.getOrElse(defaultTypeFallback(typ))
}.toOption.getOrElse(defaultTypeFallback())

val parameterMember = memberNode(parameter, parameterName, code(parameter), parameterTypeFullName)
val privateModifier = newModifierNode(ModifierTypes.PRIVATE)
val memberAst = Ast(parameterMember).withChild(Ast(privateModifier))

val accessorMethodAst = Option.unless(explicitMethodNames.contains(parameterName))(
astForRecordParameterAccessor(parameter, recordTypeFullName, parameterName, parameterTypeFullName)
)

memberAst :: accessorMethodAst.toList
}
}

private def bindingTypeForReferenceType(typ: ResolvedReferenceType): Option[JavaparserBindingDeclType] = {
typ.getTypeDeclaration.toScala.map(typeDecl =>
scope.getDeclBinding(typeDecl.getName) match {
Expand Down Expand Up @@ -343,19 +381,35 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
}

val constructorAstMap = astsForConstructors(
members.collect { case constructor: ConstructorDeclaration =>
constructor
members.collect {
case constructor: ConstructorDeclaration => constructor
case constructor: CompactConstructorDeclaration => constructor
},
instanceFields
)

val membersAsts = membersAstPairs.flatMap {
case (constructor: ConstructorDeclaration, _) =>
constructorAstMap.get(constructor)
case (_, asts) => asts
case (constructor: ConstructorDeclaration, _) => constructorAstMap.get(constructor)
case (constructor: CompactConstructorDeclaration, _) => constructorAstMap.get(constructor)
case (_, asts) => asts
}

val defaultConstructorAst = Option.when(!(isInterface || members.exists(_.isInstanceOf[ConstructorDeclaration]))) {
val hasCanonicalConstructor = scope.enclosingTypeDecl.get.recordParameters match {
case Nil => members.exists(member => member.isConstructorDeclaration || member.isCompactConstructorDeclaration)

case recordParameters =>
members.collect {
case compactConstructorDeclaration: CompactConstructorDeclaration => compactConstructorDeclaration

case constructorDeclaration: ConstructorDeclaration
if constructorDeclaration.getParameters.asScala
.map(_.getType)
.toList
.equals(recordParameters.map(_.getType)) =>
constructorDeclaration
}.nonEmpty
}
val defaultConstructorAst = Option.when(!(isInterface || hasCanonicalConstructor)) {
astForDefaultConstructor(originNode, instanceFields)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ trait AstForCallExpressionsCreator { this: AstCreator =>
val expressionTypeFullName =
expressionReturnTypeFullName(call).orElse(getTypeFullName(expectedReturnType)).map(typeInfoCalc.registerType)

val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall)
val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall.toOption)
val returnType = maybeResolvedCall
.map { resolvedCall =>
typeInfoCalc.fullName(resolvedCall.getReturnType, ResolvedTypeParametersMap.empty())
Expand Down Expand Up @@ -232,7 +232,7 @@ trait AstForCallExpressionsCreator { this: AstCreator =>
scope.addLocalDecl(anonymousClassDecl)
}

val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr)
val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr.toOption)

val allocNode = newOperatorCallNode(
Operators.alloc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator =>
.find { identifier => identifier.name == NameConstants.This || identifier.name == NameConstants.Super }
.map { _ =>
val typeFullName = scope.enclosingTypeDecl.fullName
Ast(thisNodeForMethod(typeFullName, line(expr)))
Ast(thisNodeForMethod(typeFullName, line(expr), column(expr)))
}
.toList

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ trait AstForSimpleExpressionsCreator { this: AstCreator =>

case Success(resolvedMethod) =>
val returnType = tryWithSafeStackOverflow(resolvedMethod.getReturnType).toOption.flatMap(typeInfoCalc.fullName)
val parameterTypes = argumentTypesForMethodLike(Success(resolvedMethod))
val parameterTypes = argumentTypesForMethodLike(Option(resolvedMethod))
composeSignature(returnType, parameterTypes, resolvedMethod.getNumberOfParams)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ trait AstForSimpleStatementsCreator { this: AstCreator =>
// TODO Handle super
val maybeResolved = tryWithSafeStackOverflow(stmt.resolve())
val args = argAstsForCall(stmt, maybeResolved, stmt.getArguments)
val argTypes = argumentTypesForMethodLike(maybeResolved)
val argTypes = argumentTypesForMethodLike(maybeResolved.toOption)

// TODO: We can do better than defaultTypeFallback() for the fallback type by looking at the enclosing
// type decl name or `extends X` name for `this` and `super` calls respectively.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.joern.javasrc2cpg.scope

import com.github.javaparser.ast.body.Parameter
import com.github.javaparser.ast.expr.TypePatternExpr
import io.joern.javasrc2cpg.scope.Scope.*
import io.joern.javasrc2cpg.scope.JavaScopeElement.*
Expand Down Expand Up @@ -175,7 +176,8 @@ object JavaScopeElement {
override val isStatic: Boolean,
private[scope] val capturedVariables: Map[String, CapturedVariable],
outerClassType: Option[String],
val declaredMethodNames: Set[String]
val declaredMethodNames: Set[String],
val recordParameters: List[Parameter]
)(implicit disableTypeFallback: Boolean)
extends JavaScopeElement(disableTypeFallback)
with TypeDeclContainer
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.joern.javasrc2cpg.scope

import com.github.javaparser.ast.body.Parameter
import com.github.javaparser.ast.expr.TypePatternExpr
import io.joern.javasrc2cpg.astcreation.ExpectedType
import io.joern.javasrc2cpg.scope.Scope.*
Expand Down Expand Up @@ -40,7 +41,12 @@ class Scope(implicit val withSchemaValidation: ValidationMode, val disableTypeFa
scopeStack = new FieldDeclScope(isStatic, name) :: scopeStack
}

def pushTypeDeclScope(typeDecl: NewTypeDecl, isStatic: Boolean, methodNames: Set[String] = Set.empty): Unit = {
def pushTypeDeclScope(
typeDecl: NewTypeDecl,
isStatic: Boolean,
methodNames: Set[String] = Set.empty,
recordParameters: List[Parameter] = Nil
): Unit = {
val captures = getCapturesForNewScope(isStatic)
val outerClassType = scopeStack.takeUntil(_.isInstanceOf[TypeDeclScope]) match {
case Nil => None
Expand All @@ -58,7 +64,8 @@ class Scope(implicit val withSchemaValidation: ValidationMode, val disableTypeFa
}
.flatten
}
scopeStack = new TypeDeclScope(typeDecl, isStatic, captures, outerClassType, methodNames) :: scopeStack
scopeStack =
new TypeDeclScope(typeDecl, isStatic, captures, outerClassType, methodNames, recordParameters) :: scopeStack
}

def pushNamespaceScope(namespace: NewNamespaceBlock): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MethodParameterTests2 extends JavaSrcCode2CpgFixture {
param.order shouldBe 0
param.index shouldBe 0
param.lineNumber shouldBe Some(3)
param.columnNumber shouldBe None
param.columnNumber shouldBe Some(3)
param.typeFullName shouldBe "Foo"
param.evaluationStrategy shouldBe EvaluationStrategies.BY_SHARING
}
Expand Down
Loading

0 comments on commit 5af4cc0

Please sign in to comment.