Skip to content

Commit

Permalink
Added a Java deserialization remediator (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
nahsra authored Jul 26, 2024
1 parent f8af718 commit 848ff93
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package io.codemodder.remediation.javadeserialization;

import static io.codemodder.javaparser.JavaParserTransformer.replace;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import io.codemodder.CodemodChange;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.DependencyGAV;
import io.codemodder.ast.ASTs;
import io.codemodder.ast.LocalDeclaration;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.codetf.FixedFinding;
import io.codemodder.codetf.UnfixedFinding;
import io.codemodder.remediation.FixCandidate;
import io.codemodder.remediation.FixCandidateSearchResults;
import io.codemodder.remediation.FixCandidateSearcher;
import io.github.pixee.security.ObjectInputFilters;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

final class DefaultJavaDeserializationRemediator implements JavaDeserializationRemediator {

@Override
public <T> CodemodFileScanningResult remediateAll(
final CompilationUnit cu,
final String path,
final DetectorRule detectorRule,
final List<T> issuesForFile,
final Function<T, String> getKey,
final Function<T, Integer> getLine,
final Function<T, Integer> getColumn) {
FixCandidateSearcher<T> searcher =
new FixCandidateSearcher.Builder<T>()
.withMethodName("readObject")
.withMatcher(mce -> mce.getScope().isPresent())
.withMatcher(mce -> mce.getArguments().isEmpty())
.build();

FixCandidateSearchResults<T> results =
searcher.search(cu, path, detectorRule, issuesForFile, getKey, getLine, getColumn);

List<CodemodChange> changes = new ArrayList<>();
List<UnfixedFinding> unfixedFindings = new ArrayList<>();
for (FixCandidate<T> fixCandidate : results.fixCandidates()) {
List<T> issues = fixCandidate.issues();
MethodCallExpr call = fixCandidate.methodCall();
// get the declaration of the ObjectInputStream
Expression callScope = call.getScope().get();
if (!callScope.isNameExpr()) {
// can't fix these
issues.stream()
.map(
i ->
new UnfixedFinding(
getKey.apply(i), detectorRule, path, getLine.apply(i), "Unexpected shape"))
.forEach(unfixedFindings::add);
continue;
}

Optional<LocalDeclaration> declaration =
ASTs.findEarliestLocalDeclarationOf(callScope.asNameExpr().getName());
if (declaration.isEmpty()) {
issues.stream()
.map(
i ->
new UnfixedFinding(
getKey.apply(i),
detectorRule,
path,
getLine.apply(i),
"No declaration found"))
.forEach(unfixedFindings::add);
continue;
}

LocalDeclaration localDeclaration = declaration.get();
Node varDeclarationAndExpr = localDeclaration.getDeclaration();
if (varDeclarationAndExpr instanceof VariableDeclarator varDec) {
Optional<Expression> initializer = varDec.getInitializer();
if (initializer.isEmpty()) {
issues.stream()
.map(
i ->
new UnfixedFinding(
getKey.apply(i),
detectorRule,
path,
getLine.apply(i),
"No initializer found"))
.forEach(unfixedFindings::add);
continue;
}

Expression expression = initializer.get();
if (expression instanceof ObjectCreationExpr objCreation) {
fixObjectInputStreamCreation(objCreation);
CodemodChange change =
CodemodChange.from(
getLine.apply(issues.get(0)),
List.of(DependencyGAV.JAVA_SECURITY_TOOLKIT),
issues.stream()
.map(i -> new FixedFinding(getKey.apply(i), detectorRule))
.toList());
changes.add(change);
}
} else {
issues.stream()
.map(
i ->
new UnfixedFinding(
getKey.apply(i),
detectorRule,
path,
getLine.apply(i),
"Unexpected declaration type"))
.forEach(unfixedFindings::add);
}
}
return CodemodFileScanningResult.from(changes, unfixedFindings);
}

private void fixObjectInputStreamCreation(final ObjectCreationExpr objCreation) {
replace(objCreation)
.withStaticMethod(ObjectInputFilters.class.getName(), "createSafeObjectInputStream")
.withStaticImport()
.withSameArguments();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.codemodder.remediation.javadeserialization;

import com.github.javaparser.ast.CompilationUnit;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.codetf.DetectorRule;
import java.util.List;
import java.util.function.Function;

/** Remediates Java deserialization vulnerabilities. */
public interface JavaDeserializationRemediator {

/** Remediate all Java deserialization vulnerabilities in the given compilation unit. */
<T> CodemodFileScanningResult remediateAll(
CompilationUnit cu,
String path,
DetectorRule detectorRule,
List<T> issuesForFile,
Function<T, String> getKey,
Function<T, Integer> getLine,
Function<T, Integer> getColumn);

/** The default header injection remediation strategy. */
JavaDeserializationRemediator DEFAULT = new DefaultJavaDeserializationRemediator();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package io.codemodder.remediation.javadeserialization;

import static org.assertj.core.api.Assertions.assertThat;

import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter;
import io.codemodder.CodemodChange;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.DependencyGAV;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.codetf.FixedFinding;
import io.codemodder.codetf.UnfixedFinding;
import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

final class DefaultJavaDeserializationRemediatorTest {

private DefaultJavaDeserializationRemediator remediator;
private DetectorRule rule;

@BeforeEach
void setup() {
remediator = new DefaultJavaDeserializationRemediator();
rule = new DetectorRule("untrusted-deserialization", "Untrusted Deserialization", null);
}

private static Stream<Arguments> unfixableSamples() {
return Stream.of(
Arguments.of(
"""
import java.io.ObjectInputStream;
class Foo {
void bar(ObjectInputStream ois) {
Acme acme = ois.readObject();
}
}
""",
4,
"Unexpected declaration type"),
Arguments.of(
"""
import java.io.ObjectInputStream;
class Foo {
void bar() {
Acme acme = getOis().readObject();
}
}
""",
4,
"Unexpected shape"));
}

@ParameterizedTest
@MethodSource("unfixableSamples")
void it_doesnt_handle_unfixables(final String badCode, final int line, final String reason) {
CompilationUnit cu = StaticJavaParser.parse(badCode);
LexicalPreservingPrinter.setup(cu);

CodemodFileScanningResult result =
remediator.remediateAll(
cu, "path", rule, List.of(new Object()), o -> "id", o -> line, o -> null);
assertThat(result.unfixedFindings()).hasSize(1);
assertThat(result.changes()).isEmpty();
UnfixedFinding unfixedFinding = result.unfixedFindings().get(0);
assertThat(unfixedFinding.getReason()).isEqualTo(reason);
assertThat(unfixedFinding.getRule()).isEqualTo(rule);
assertThat(unfixedFinding.getLine()).isEqualTo(line);
assertThat(unfixedFinding.getPath()).isEqualTo("path");
}

@Test
void it_fixes_java_deserialization() {

String fixableCode =
"""
package com.acme;
import java.io.ObjectInputStream;
import java.io.InputStream;
class Foo {
Acme readAcme(InputStream is) {
ObjectInputStream ois = new ObjectInputStream(is);
// read the obj
Acme acme = (Acme) ois.readObject();
return acme;
}
}
""";

CompilationUnit cu = StaticJavaParser.parse(fixableCode);
LexicalPreservingPrinter.setup(cu);

CodemodFileScanningResult result =
remediator.remediateAll(
cu, "path", rule, List.of(new Object()), o -> "id", o -> 9, o -> null);
assertThat(result.unfixedFindings()).isEmpty();
assertThat(result.changes()).hasSize(1);
CodemodChange change = result.changes().get(0);
assertThat(change.lineNumber()).isEqualTo(9);
List<FixedFinding> fixedFindings = change.getFixedFindings();
assertThat(fixedFindings).hasSize(1);
assertThat(change.getDependenciesNeeded()).containsExactly(DependencyGAV.JAVA_SECURITY_TOOLKIT);

assertThat(fixedFindings.get(0).getId()).isEqualTo("id");
assertThat(fixedFindings.get(0).getRule()).isEqualTo(rule);

String afterCode = LexicalPreservingPrinter.print(cu);
assertThat(afterCode)
.isEqualToIgnoringWhitespace(
"""
package com.acme;
import static io.github.pixee.security.ObjectInputFilters.createSafeObjectInputStream;
import java.io.ObjectInputStream;
import java.io.InputStream;
class Foo {
Acme readAcme(InputStream is) {
ObjectInputStream ois = createSafeObjectInputStream(is);
// read the obj
Acme acme = (Acme) ois.readObject();
return acme;
}
}
""");
}
}

0 comments on commit 848ff93

Please sign in to comment.