Skip to content

Commit

Permalink
JUnit 4 AssertArrayEquals/AssertNotEquals to AssertJ (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkvangorder authored Dec 10, 2020
1 parent 9e6d5b6 commit ca93760
Show file tree
Hide file tree
Showing 10 changed files with 1,164 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright 2020 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* https://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.openrewrite.java.testing.junitassertj;

import org.openrewrite.AutoConfigure;
import org.openrewrite.java.AutoFormat;
import org.openrewrite.java.JavaIsoRefactorVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.tree.*;

import java.util.List;

import static org.openrewrite.java.tree.MethodTypeBuilder.newMethodType;

/**
* This is a refactoring visitor that will convert JUnit-style assertArrayEquals() to assertJ's assertThat().containsExactly().
*
* This visitor will handle the following JUnit 5 method signatures:
*
* <PRE>
* Two parameter variants:
*
* assertArrayEquals(expected,actual) -> assertThat(actual).containsExactly(expected)
*
* Three parameter variant where the third argument is a String:
*
* assertArrayEquals(expected, actual, "message") -> assertThat(actual).as("message").containsExactly(expected)
*
* Three parameter variant where the third argument is a String supplier:
*
* assertArrayEquals(expected, actual, () -> "message") -> assertThat(actual).withFailureMessage("message").containsExactly(expected)
*
* Three parameter variant where args are all floating point numbers.
*
* assertArrayEquals(expected, actual, delta) -> assertThat(actual).containsExactly(expected, within(delta));
*
* Four parameter variant when comparing floating point numbers with a delta and a message:
*
* assertArrayEquals(expected, actual, delta, "message") -> assertThat(actual).withFailureMessage("message").containsExactly(expected, within(delta));
*
* </PRE>
*/
@AutoConfigure
public class AssertArrayEqualsToAssertThat extends JavaIsoRefactorVisitor {

private static final String JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.junit.jupiter.api.Assertions";

private static final String ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.assertj.core.api.Assertions";
private static final String ASSERTJ_ASSERT_THAT_METHOD_NAME = "assertThat";
private static final String ASSERTJ_WITHIN_METHOD_NAME = "within";

/**
* This matcher finds the junit methods that will be migrated by this visitor.
*/
private static final MethodMatcher JUNIT_ASSERT_EQUALS_MATCHER = new MethodMatcher(
JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertArrayEquals(..)"
);

private static final JavaType ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT = newMethodType()
.declaringClass("org.assertj.core.api.Assertions")
.flags(Flag.Public, Flag.Static)
.name("*")
.build();

public AssertArrayEqualsToAssertThat() {
setCursoringOn();
}

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {

J.MethodInvocation original = super.visitMethodInvocation(method);

if (!JUNIT_ASSERT_EQUALS_MATCHER.matches(method)) {
return original;
}

List<Expression> originalArgs = original.getArgs().getArgs();

Expression expected = originalArgs.get(0);
Expression actual = originalArgs.get(1);

J.MethodInvocation replacement;
if (originalArgs.size() == 2) {
//assertThat(actual).isEqualTo(expected)
replacement = assertSimple(actual, expected);
} else if (originalArgs.size() == 3 && !isFloatingPointType(originalArgs.get(2))) {
//assertThat(actual).as(message).isEqualTo(expected)
replacement = assertWithMessage(actual, expected, originalArgs.get(2));
} else if (originalArgs.size() == 3) {
//assert is using floating points with a delta and no message.
replacement = assertFloatingPointDelta(actual, expected, originalArgs.get(2));
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);

} else {
//The assertEquals is using a floating point with a delta argument and a message.
replacement = assertFloatingPointDeltaWithMessage(actual, expected, originalArgs.get(2), originalArgs.get(3));
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
}

//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat"
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_ASSERT_THAT_METHOD_NAME);
//And if there are no longer references to the JUnit assertions class, we can remove the import.
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);

//Format the replacement method invocation in the context of where it is called.
andThen(new AutoFormat(replacement));
return replacement;
}

private J.MethodInvocation assertSimple(Expression actual, Expression expected) {

List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
String.format("assertThat(%s).containsExactly(%s);", actual.printTrimmed(), expected.printTrimmed()),
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
);
return statements.get(0);
}

private J.MethodInvocation assertWithMessage(Expression actual, Expression expected, Expression message) {

// In assertJ the "as" method has a more informative error message, but doesn't accept String suppliers
// so we're using "as" if the message is a string and "withFailMessage" if it is a supplier.
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";

List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
String.format("assertThat(%s).%s(%s).containsExactly(%s);",
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed()),
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
);
return statements.get(0);
}

private J.MethodInvocation assertFloatingPointDelta(Expression actual, Expression expected, Expression delta) {
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
String.format("assertThat(%s).containsExactly(%s, within(%s));",
actual.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
);
return statements.get(0);
}

private J.MethodInvocation assertFloatingPointDeltaWithMessage(Expression actual, Expression expected,
Expression delta, Expression message) {

//If the message is a string use "as", if it is a supplier use "withFailMessage"
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";

List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
String.format("assertThat(%s).%s(%s).containsExactly(%s, within(%s));",
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
);
return statements.get(0);
}

/**
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
*
* @param expression The expression parsed from the original AST.
* @return true if the type is a floating point number.
*/
private boolean isFloatingPointType(Expression expression) {

JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
if (fullyQualified != null) {
String typeName = fullyQualified.getFullyQualifiedName();
return (typeName.equals("java.lang.Double") || typeName.equals("java.lang.Float"));
}

JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,21 @@
*
* assertEquals(expected,actual) -> assertThat(actual).isEqualTo(expected)
*
* Three parameter variant where the third argument is either a String or String Supplier:
* Three parameter variant where the third argument is a String:
*
* assertEquals(expected, actual, "message") -> assertThat(actual).withFailureMessage("message").isEqualTo(expected)
* assertEquals(expected, actual, "message") -> assertThat(actual).as("message").isEqualTo(expected)
*
* Three parameter variant where the third argument is a String supplier:
*
* assertEquals(expected, actual, () -> "message") -> assertThat(actual).withFailureMessage("message").isEqualTo(expected)
*
* Three parameter variant where args are all floating point numbers.
*
* assertEquals(expected, actual, delta) -> assertThat(actual).isCloseTo(expected, within(delta));
*
* Four parameter variant when comparing floating point numbers with a delta and a message:
*
* assertEquals(expected, actual, delta) -> assertThat(actual).withFailureMessage("message").isCloseTo(expected, within(delta));
* assertEquals(expected, actual, delta, "message") -> assertThat(actual).withFailureMessage("message").isCloseTo(expected, within(delta));
*
* </PRE>
*/
Expand All @@ -59,7 +63,7 @@ public class AssertEqualsToAssertThat extends JavaIsoRefactorVisitor {
private static final String ASSERTJ_WITHIN_METHOD_NAME = "within";

/**
* This matcher uses a pointcut expression to find the matching junit methods that will be migrated by this visitor
* This matcher finds the junit methods that will be migrated by this visitor.
*/
private static final MethodMatcher JUNIT_ASSERT_EQUALS_MATCHER = new MethodMatcher(
JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertEquals(..)"
Expand All @@ -75,12 +79,6 @@ public AssertEqualsToAssertThat() {
setCursoringOn();
}

@Override
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu) {
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);
return super.visitCompilationUnit(cu);
}

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {

Expand All @@ -107,14 +105,15 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);

} else {
//The assertEquals is using a primitive floating point with a delta argument. (There may be an optional)
//fourth argument that contains the message.
//The assertEquals is using a floating point with a delta argument and a message.
replacement = assertFloatingPointDeltaWithMessage(actual, expected, originalArgs.get(2), originalArgs.get(3));
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
}

//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat"
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_ASSERT_THAT_METHOD_NAME);
//And if there are no longer references to the JUnit assertions class, we can remove the import.
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);

//Format the replacement method invocation in the context of where it is called.
andThen(new AutoFormat(replacement));
Expand All @@ -131,9 +130,14 @@ private J.MethodInvocation assertSimple(Expression actual, Expression expected)
}

private J.MethodInvocation assertWithMessage(Expression actual, Expression expected, Expression message) {

// In assertJ the "as" method has a more informative error message, but doesn't accept String suppliers
// so we're using "as" if the message is a string and "withFailMessage" if it is a supplier.
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";

List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
String.format("assertThat(%s).withFailMessage(%s).isEqualTo(%s);",
actual.printTrimmed(), message.printTrimmed(), expected.printTrimmed()),
String.format("assertThat(%s).%s(%s).isEqualTo(%s);",
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed()),
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
);
return statements.get(0);
Expand All @@ -150,21 +154,32 @@ private J.MethodInvocation assertFloatingPointDelta(Expression actual, Expressio

private J.MethodInvocation assertFloatingPointDeltaWithMessage(Expression actual, Expression expected,
Expression delta, Expression message) {

//If the message is a string use "as", if it is a supplier use "withFailMessage"
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";

List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
String.format("assertThat(%s).withFailMessage(%s).isCloseTo(%s, within(%s));",
actual.printTrimmed(), message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
String.format("assertThat(%s).%s(%s).isCloseTo(%s, within(%s));",
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
);
return statements.get(0);
}

/**
* Returns true if the expression's type is either a primitive float or double.
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
*
* @param expression The expression parsed from the original AST.
* @return true if the type is a floating point number.
*/
private boolean isFloatingPointType(Expression expression) {

JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
if (fullyQualified != null) {
String typeName = fullyQualified.getFullyQualifiedName();
return (typeName.equals("java.lang.Double") || typeName.equals("java.lang.Float"));
}

JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}
Expand Down
Loading

0 comments on commit ca93760

Please sign in to comment.