diff --git a/framework/codemodder-base/src/main/java/io/codemodder/remediation/javadeserialization/DefaultJavaDeserializationRemediator.java b/framework/codemodder-base/src/main/java/io/codemodder/remediation/javadeserialization/DefaultJavaDeserializationRemediator.java new file mode 100644 index 000000000..ecb70419d --- /dev/null +++ b/framework/codemodder-base/src/main/java/io/codemodder/remediation/javadeserialization/DefaultJavaDeserializationRemediator.java @@ -0,0 +1,135 @@ +package io.codemodder.remediation.javadeserialization; + +import static io.codemodder.javaparser.JavaParserTransformer.replace; + +import com.github.javaparser.ast.CompilationUnit; +import com.github.javaparser.ast.Node; +import com.github.javaparser.ast.body.VariableDeclarator; +import com.github.javaparser.ast.expr.Expression; +import com.github.javaparser.ast.expr.MethodCallExpr; +import com.github.javaparser.ast.expr.ObjectCreationExpr; +import io.codemodder.CodemodChange; +import io.codemodder.CodemodFileScanningResult; +import io.codemodder.DependencyGAV; +import io.codemodder.ast.ASTs; +import io.codemodder.ast.LocalDeclaration; +import io.codemodder.codetf.DetectorRule; +import io.codemodder.codetf.FixedFinding; +import io.codemodder.codetf.UnfixedFinding; +import io.codemodder.remediation.FixCandidate; +import io.codemodder.remediation.FixCandidateSearchResults; +import io.codemodder.remediation.FixCandidateSearcher; +import io.github.pixee.security.ObjectInputFilters; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +final class DefaultJavaDeserializationRemediator implements JavaDeserializationRemediator { + + @Override + public CodemodFileScanningResult remediateAll( + final CompilationUnit cu, + final String path, + final DetectorRule detectorRule, + final List issuesForFile, + final Function getKey, + final Function getLine, + final Function getColumn) { + FixCandidateSearcher searcher = + new FixCandidateSearcher.Builder() + .withMethodName("readObject") + .withMatcher(mce -> mce.getScope().isPresent()) + .withMatcher(mce -> mce.getArguments().isEmpty()) + .build(); + + FixCandidateSearchResults results = + searcher.search(cu, path, detectorRule, issuesForFile, getKey, getLine, getColumn); + + List changes = new ArrayList<>(); + List unfixedFindings = new ArrayList<>(); + for (FixCandidate fixCandidate : results.fixCandidates()) { + List issues = fixCandidate.issues(); + MethodCallExpr call = fixCandidate.methodCall(); + // get the declaration of the ObjectInputStream + Expression callScope = call.getScope().get(); + if (!callScope.isNameExpr()) { + // can't fix these + issues.stream() + .map( + i -> + new UnfixedFinding( + getKey.apply(i), detectorRule, path, getLine.apply(i), "Unexpected shape")) + .forEach(unfixedFindings::add); + continue; + } + + Optional declaration = + ASTs.findEarliestLocalDeclarationOf(callScope.asNameExpr().getName()); + if (declaration.isEmpty()) { + issues.stream() + .map( + i -> + new UnfixedFinding( + getKey.apply(i), + detectorRule, + path, + getLine.apply(i), + "No declaration found")) + .forEach(unfixedFindings::add); + continue; + } + + LocalDeclaration localDeclaration = declaration.get(); + Node varDeclarationAndExpr = localDeclaration.getDeclaration(); + if (varDeclarationAndExpr instanceof VariableDeclarator varDec) { + Optional initializer = varDec.getInitializer(); + if (initializer.isEmpty()) { + issues.stream() + .map( + i -> + new UnfixedFinding( + getKey.apply(i), + detectorRule, + path, + getLine.apply(i), + "No initializer found")) + .forEach(unfixedFindings::add); + continue; + } + + Expression expression = initializer.get(); + if (expression instanceof ObjectCreationExpr objCreation) { + fixObjectInputStreamCreation(objCreation); + CodemodChange change = + CodemodChange.from( + getLine.apply(issues.get(0)), + List.of(DependencyGAV.JAVA_SECURITY_TOOLKIT), + issues.stream() + .map(i -> new FixedFinding(getKey.apply(i), detectorRule)) + .toList()); + changes.add(change); + } + } else { + issues.stream() + .map( + i -> + new UnfixedFinding( + getKey.apply(i), + detectorRule, + path, + getLine.apply(i), + "Unexpected declaration type")) + .forEach(unfixedFindings::add); + } + } + return CodemodFileScanningResult.from(changes, unfixedFindings); + } + + private void fixObjectInputStreamCreation(final ObjectCreationExpr objCreation) { + replace(objCreation) + .withStaticMethod(ObjectInputFilters.class.getName(), "createSafeObjectInputStream") + .withStaticImport() + .withSameArguments(); + } +} diff --git a/framework/codemodder-base/src/main/java/io/codemodder/remediation/javadeserialization/JavaDeserializationRemediator.java b/framework/codemodder-base/src/main/java/io/codemodder/remediation/javadeserialization/JavaDeserializationRemediator.java new file mode 100644 index 000000000..1040e37a2 --- /dev/null +++ b/framework/codemodder-base/src/main/java/io/codemodder/remediation/javadeserialization/JavaDeserializationRemediator.java @@ -0,0 +1,24 @@ +package io.codemodder.remediation.javadeserialization; + +import com.github.javaparser.ast.CompilationUnit; +import io.codemodder.CodemodFileScanningResult; +import io.codemodder.codetf.DetectorRule; +import java.util.List; +import java.util.function.Function; + +/** Remediates Java deserialization vulnerabilities. */ +public interface JavaDeserializationRemediator { + + /** Remediate all Java deserialization vulnerabilities in the given compilation unit. */ + CodemodFileScanningResult remediateAll( + CompilationUnit cu, + String path, + DetectorRule detectorRule, + List issuesForFile, + Function getKey, + Function getLine, + Function getColumn); + + /** The default header injection remediation strategy. */ + JavaDeserializationRemediator DEFAULT = new DefaultJavaDeserializationRemediator(); +} diff --git a/framework/codemodder-base/src/test/java/io/codemodder/remediation/javadeserialization/DefaultJavaDeserializationRemediatorTest.java b/framework/codemodder-base/src/test/java/io/codemodder/remediation/javadeserialization/DefaultJavaDeserializationRemediatorTest.java new file mode 100644 index 000000000..c5ee9b9ec --- /dev/null +++ b/framework/codemodder-base/src/test/java/io/codemodder/remediation/javadeserialization/DefaultJavaDeserializationRemediatorTest.java @@ -0,0 +1,132 @@ +package io.codemodder.remediation.javadeserialization; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.github.javaparser.StaticJavaParser; +import com.github.javaparser.ast.CompilationUnit; +import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter; +import io.codemodder.CodemodChange; +import io.codemodder.CodemodFileScanningResult; +import io.codemodder.DependencyGAV; +import io.codemodder.codetf.DetectorRule; +import io.codemodder.codetf.FixedFinding; +import io.codemodder.codetf.UnfixedFinding; +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +final class DefaultJavaDeserializationRemediatorTest { + + private DefaultJavaDeserializationRemediator remediator; + private DetectorRule rule; + + @BeforeEach + void setup() { + remediator = new DefaultJavaDeserializationRemediator(); + rule = new DetectorRule("untrusted-deserialization", "Untrusted Deserialization", null); + } + + private static Stream unfixableSamples() { + return Stream.of( + Arguments.of( + """ + import java.io.ObjectInputStream; + class Foo { + void bar(ObjectInputStream ois) { + Acme acme = ois.readObject(); + } + } + """, + 4, + "Unexpected declaration type"), + Arguments.of( + """ + import java.io.ObjectInputStream; + class Foo { + void bar() { + Acme acme = getOis().readObject(); + } + } + """, + 4, + "Unexpected shape")); + } + + @ParameterizedTest + @MethodSource("unfixableSamples") + void it_doesnt_handle_unfixables(final String badCode, final int line, final String reason) { + CompilationUnit cu = StaticJavaParser.parse(badCode); + LexicalPreservingPrinter.setup(cu); + + CodemodFileScanningResult result = + remediator.remediateAll( + cu, "path", rule, List.of(new Object()), o -> "id", o -> line, o -> null); + assertThat(result.unfixedFindings()).hasSize(1); + assertThat(result.changes()).isEmpty(); + UnfixedFinding unfixedFinding = result.unfixedFindings().get(0); + assertThat(unfixedFinding.getReason()).isEqualTo(reason); + assertThat(unfixedFinding.getRule()).isEqualTo(rule); + assertThat(unfixedFinding.getLine()).isEqualTo(line); + assertThat(unfixedFinding.getPath()).isEqualTo("path"); + } + + @Test + void it_fixes_java_deserialization() { + + String fixableCode = + """ + package com.acme; + import java.io.ObjectInputStream; + import java.io.InputStream; + + class Foo { + Acme readAcme(InputStream is) { + ObjectInputStream ois = new ObjectInputStream(is); + // read the obj + Acme acme = (Acme) ois.readObject(); + return acme; + } + } + """; + + CompilationUnit cu = StaticJavaParser.parse(fixableCode); + LexicalPreservingPrinter.setup(cu); + + CodemodFileScanningResult result = + remediator.remediateAll( + cu, "path", rule, List.of(new Object()), o -> "id", o -> 9, o -> null); + assertThat(result.unfixedFindings()).isEmpty(); + assertThat(result.changes()).hasSize(1); + CodemodChange change = result.changes().get(0); + assertThat(change.lineNumber()).isEqualTo(9); + List fixedFindings = change.getFixedFindings(); + assertThat(fixedFindings).hasSize(1); + assertThat(change.getDependenciesNeeded()).containsExactly(DependencyGAV.JAVA_SECURITY_TOOLKIT); + + assertThat(fixedFindings.get(0).getId()).isEqualTo("id"); + assertThat(fixedFindings.get(0).getRule()).isEqualTo(rule); + + String afterCode = LexicalPreservingPrinter.print(cu); + assertThat(afterCode) + .isEqualToIgnoringWhitespace( + """ + package com.acme; + import static io.github.pixee.security.ObjectInputFilters.createSafeObjectInputStream; + import java.io.ObjectInputStream; + import java.io.InputStream; + + class Foo { + Acme readAcme(InputStream is) { + ObjectInputStream ois = createSafeObjectInputStream(is); + // read the obj + Acme acme = (Acme) ois.readObject(); + return acme; + } + } + """); + } +}