Skip to content

Commit

Permalink
Improve XSS fixer and create CodeQL mapping (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
nahsra authored Nov 15, 2024
1 parent 0214068 commit d9b49a0
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public static List<Class<? extends CodeChanger>> asList() {
CodeQLSSRFCodemod.class,
CodeQLStackTraceExposureCodemod.class,
CodeQLUnverifiedJwtCodemod.class,
CodeQLXSSCodemod.class,
CodeQLXXECodemod.class,
DeclareVariableOnSeparateLineCodemod.class,
DefectDojoSqlInjectionCodemod.class,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.codemodder.codemods.codeql;

import com.contrastsecurity.sarif.Result;
import com.github.javaparser.ast.CompilationUnit;
import io.codemodder.*;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.providers.sarif.codeql.ProvidedCodeQLScan;
import io.codemodder.remediation.GenericRemediationMetadata;
import io.codemodder.remediation.Remediator;
import io.codemodder.remediation.xss.XSSRemediator;
import java.util.Optional;
import javax.inject.Inject;

/** A codemod for automatically fixing XSS from CodeQL. */
@Codemod(
id = "codeql:java/xss",
reviewGuidance = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW,
importance = Importance.HIGH,
executionPriority = CodemodExecutionPriority.HIGH)
public final class CodeQLXSSCodemod extends CodeQLRemediationCodemod {

private final Remediator<Result> remediator;

@Inject
public CodeQLXSSCodemod(@ProvidedCodeQLScan(ruleId = "java/xss") final RuleSarif sarif) {
super(GenericRemediationMetadata.XSS.reporter(), sarif);
this.remediator = new XSSRemediator<>();
}

@Override
public DetectorRule detectorRule() {
return new DetectorRule(
"xss",
"Cross-site scripting",
"https://codeql.github.com/codeql-query-help/java/java-xss/");
}

@Override
public CodemodFileScanningResult visit(
final CodemodInvocationContext context, final CompilationUnit cu) {
return remediator.remediateAll(
cu,
context.path().toString(),
detectorRule(),
ruleSarif.getResultsByLocationPath(context.path()),
SarifFindingKeyUtil::buildFindingId,
r -> r.getLocations().get(0).getPhysicalLocation().getRegion().getStartLine(),
r ->
Optional.ofNullable(
r.getLocations().get(0).getPhysicalLocation().getRegion().getEndLine()),
r ->
Optional.ofNullable(
r.getLocations().get(0).getPhysicalLocation().getRegion().getStartColumn()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
import io.codemodder.remediation.SuccessOrReason;
import java.util.List;
import java.util.Optional;
import org.jetbrains.annotations.VisibleForTesting;

/**
* Fix strategy for XSS vulnerabilities where a variable is returned directly and that is what's
* vulnerable.
*/
final class NakedVariableReturnFixStrategy implements RemediationStrategy {

@Override
public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
var maybeReturn = Optional.of(node).map(n -> n instanceof ReturnStmt ? (ReturnStmt) n : null);
Expand All @@ -25,8 +29,7 @@ public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
return SuccessOrReason.success(List.of(DependencyGAV.OWASP_XSS_JAVA_ENCODER));
}

@VisibleForTesting
public static boolean match(final Node node) {
static boolean match(final Node node) {
return Optional.of(node)
.map(n -> n instanceof ReturnStmt ? (ReturnStmt) n : null)
.filter(rs -> rs.getExpression().isPresent())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.expr.BinaryExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import io.codemodder.DependencyGAV;
import io.codemodder.remediation.RemediationStrategy;
import io.codemodder.remediation.SuccessOrReason;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.jetbrains.annotations.VisibleForTesting;

/** Fix strategy for XSS vulnerabilities where a variable is sent to a simple print/write call. */
final class PrintingMethodFixStrategy implements RemediationStrategy {

@Override
Expand All @@ -23,14 +25,55 @@ public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
return SuccessOrReason.reason("Not a method call.");
}
MethodCallExpr call = maybeCall.get();
wrap(call.getArgument(0)).withStaticMethod("org.owasp.encoder.Encode", "forHtml", false);

Expression methodArgument = call.getArgument(0);

Optional<Expression> thingToWrap = findExpressionToWrap(methodArgument);
if (thingToWrap.isEmpty()) {
return SuccessOrReason.reason("Could not find recognize code shape to fix.");
}
Expression expressionToWrap = thingToWrap.get();
wrap(expressionToWrap).withStaticMethod("org.owasp.encoder.Encode", "forHtml", false);
return SuccessOrReason.success(List.of(DependencyGAV.OWASP_XSS_JAVA_ENCODER));
}

/**
* We handle 4 expression code shapes. <code>
* print(user.getName());
* print("Hello, " + user.getName());
* print(user.getName() + ", hello!");
* print("Hello, " + user.getName() + ", hello!");
* </code>
*
* <p>Note that we should only handle, for the tougher cases, string literals in combination with
* the given expression. Note any other combination of expressions.
*/
private Optional<Expression> findExpressionToWrap(final Expression expression) {
if (expression.isNameExpr()) {
return Optional.of(expression);
} else if (expression.isBinaryExpr()) {
BinaryExpr binaryExpr = expression.asBinaryExpr();
if (binaryExpr.getLeft().isBinaryExpr() && binaryExpr.getRight().isStringLiteralExpr()) {
BinaryExpr leftBinaryExpr = binaryExpr.getLeft().asBinaryExpr();
if (leftBinaryExpr.getLeft().isStringLiteralExpr()
&& !leftBinaryExpr.getRight().isStringLiteralExpr()) {
return Optional.of(leftBinaryExpr.getRight());
}
} else if (binaryExpr.getLeft().isStringLiteralExpr()
&& binaryExpr.getRight().isStringLiteralExpr()) {
return Optional.empty();
} else if (binaryExpr.getLeft().isStringLiteralExpr()) {
return Optional.of(binaryExpr.getRight());
} else if (binaryExpr.getRight().isStringLiteralExpr()) {
return Optional.of(binaryExpr.getLeft());
}
}
return Optional.empty();
}

private static final Set<String> writingMethodNames = Set.of("print", "println", "write");

@VisibleForTesting
public static boolean match(final Node node) {
static boolean match(final Node node) {
return Optional.of(node)
.map(n -> n instanceof MethodCallExpr ? (MethodCallExpr) n : null)
.filter(mce -> writingMethodNames.contains(mce.getNameAsString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import java.util.Optional;
import java.util.function.Function;

public class XSSRemediator<T> implements Remediator<T> {
/** Remediator for XSS vulnerabilities. */
public final class XSSRemediator<T> implements Remediator<T> {

private final SearcherStrategyRemediator<T> searchStrategyRemediator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.remediation.FixCandidateSearcher;
import io.codemodder.remediation.SearcherStrategyRemediator;
Expand Down Expand Up @@ -76,7 +77,55 @@ void should_be_fixed(String s) {
getWriter().write(Encode.forHtml(s));
}
}
"""));
"""),
Arguments.of(
"""
class Samples {
void should_be_fixed(String s) {
getWriter().write("<div>" + s);
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
void should_be_fixed(String s) {
getWriter().write("<div>" + Encode.forHtml(s));
}
}
"""),
Arguments.of(
"""
class Samples {
void should_be_fixed(String s) {
getWriter().write("<div>" + s + "</div>");
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
void should_be_fixed(String s) {
getWriter().write("<div>" + Encode.forHtml(s) + "</div>");
}
}
"""),
Arguments.of(
"""
class Samples {
void should_be_fixed(String s) {
getWriter().write(s + "</div>");
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
void should_be_fixed(String s) {
getWriter().write(Encode.forHtml(s) + "</div>");
}
}
"""));
}

@ParameterizedTest
Expand All @@ -85,7 +134,14 @@ void it_fixes_obvious_response_write_methods(final String beforeCode, final Stri
CompilationUnit cu = StaticJavaParser.parse(beforeCode);
LexicalPreservingPrinter.setup(cu);

XSSFinding finding = new XSSFinding("should_be_fixed", 3, null);
var result = scanAndFix(cu, 3);
assertThat(result.changes()).isNotEmpty();
String actualCode = LexicalPreservingPrinter.print(cu);
assertThat(actualCode).isEqualToIgnoringWhitespace(afterCode);
}

private CodemodFileScanningResult scanAndFix(final CompilationUnit cu, final int line) {
XSSFinding finding = new XSSFinding("should_be_fixed", line, null);
var remediator =
new SearcherStrategyRemediator.Builder<XSSFinding>()
.withSearcherStrategyPair(
Expand All @@ -94,18 +150,58 @@ void it_fixes_obvious_response_write_methods(final String beforeCode, final Stri
.build(),
fixer)
.build();
var result =
remediator.remediateAll(
cu,
"path",
rule,
List.of(finding),
XSSFinding::key,
XSSFinding::line,
x -> Optional.empty(),
x -> Optional.ofNullable(x.column()));
assertThat(result.changes().isEmpty()).isFalse();
String actualCode = LexicalPreservingPrinter.print(cu);
assertThat(actualCode).isEqualToIgnoringWhitespace(afterCode);
return remediator.remediateAll(
cu,
"path",
rule,
List.of(finding),
XSSFinding::key,
XSSFinding::line,
x -> Optional.empty(),
x -> Optional.ofNullable(x.column()));
}

@ParameterizedTest
@MethodSource("unfixableSamples")
void it_does_not_fix_unfixable_response_write_methods(final String beforeCode, final int line) {
CompilationUnit cu = StaticJavaParser.parse(beforeCode);
LexicalPreservingPrinter.setup(cu);
var result = scanAndFix(cu, line);
assertThat(result.changes()).isEmpty();
}

private static Stream<Arguments> unfixableSamples() {
return Stream.of(
// this is all string literals -- ignore
Arguments.of(
"""
class Samples {
void should_be_fixed(String s) {
getWriter().write("<div>" + "<b>" + "</div>");
}
}
""",
3),
// this is ambiguous which value to encode
Arguments.of(
"""
class Samples {
void should_be_fixed(String s) {
getWriter().write("<div>" + a + b + "</div>");
}
}
""",
3),
// this is the wrong line
Arguments.of(
"""
class Samples {
void should_be_fixed(String s) {
// extra line, right line is 4
getWriter().write("<div>" + a + "</div>");
}
}
""",
3));
}
}

0 comments on commit d9b49a0

Please sign in to comment.