diff --git a/src/main/java/org/openrewrite/hibernate/MigrateBooleanMappings.java b/src/main/java/org/openrewrite/hibernate/MigrateBooleanMappings.java
index 6702395..07104c1 100644
--- a/src/main/java/org/openrewrite/hibernate/MigrateBooleanMappings.java
+++ b/src/main/java/org/openrewrite/hibernate/MigrateBooleanMappings.java
@@ -94,7 +94,7 @@ public J.Annotation visitAnnotation(J.Annotation annotation, ExecutionContext ct
String converterFQN = String.format("org.hibernate.type.%s", converterName);
ann = JavaTemplate.builder(String.format("@Convert(converter = %s.class)", converterName))
- .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "hibernate-core", "jakarta.persistence-api"))
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "hibernate-core-6+", "jakarta.persistence-api"))
.imports(converterFQN, "jakarta.persistence.Convert")
.contextSensitive()
.build().apply(getCursor(), ann.getCoordinates().replace());
diff --git a/src/main/java/org/openrewrite/hibernate/MigrateUserType.java b/src/main/java/org/openrewrite/hibernate/MigrateUserType.java
new file mode 100644
index 0000000..db9804c
--- /dev/null
+++ b/src/main/java/org/openrewrite/hibernate/MigrateUserType.java
@@ -0,0 +1,251 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.hibernate;
+
+import org.jspecify.annotations.Nullable;
+import org.openrewrite.ExecutionContext;
+import org.openrewrite.Preconditions;
+import org.openrewrite.Recipe;
+import org.openrewrite.TreeVisitor;
+import org.openrewrite.internal.ListUtils;
+import org.openrewrite.java.*;
+import org.openrewrite.java.search.FindImplementations;
+import org.openrewrite.java.search.FindMethodDeclaration;
+import org.openrewrite.java.tree.*;
+import org.openrewrite.marker.Markers;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.openrewrite.Tree.randomId;
+
+public class MigrateUserType extends Recipe {
+
+ private static final String USER_TYPE = "org.hibernate.usertype.UserType";
+ private static final MethodMatcher ASSEMBLE = new MethodMatcher("* assemble(java.io.Serializable, java.lang.Object)");
+ private static final MethodMatcher DEEP_COPY = new MethodMatcher("* deepCopy(java.lang.Object)");
+ private static final MethodMatcher DISASSEMBLE = new MethodMatcher("* disassemble(java.lang.Object)");
+ private static final MethodMatcher EQUALS = new MethodMatcher("* equals(java.lang.Object, java.lang.Object)");
+ private static final MethodMatcher HASHCODE = new MethodMatcher("* hashCode(java.lang.Object)");
+ private static final MethodMatcher NULL_SAFE_GET_STRING_ARRAY = new MethodMatcher("* nullSafeGet(java.sql.ResultSet, java.lang.String[], org.hibernate.engine.spi.SharedSessionContractImplementor, java.lang.Object)");
+ private static final MethodMatcher NULL_SAFE_SET = new MethodMatcher("* nullSafeSet(java.sql.PreparedStatement, java.lang.Object, int, org.hibernate.engine.spi.SharedSessionContractImplementor)");
+ private static final MethodMatcher NULL_SAFE_GET_INT = new MethodMatcher("* nullSafeGet(java.sql.ResultSet, int, org.hibernate.engine.spi.SharedSessionContractImplementor, java.lang.Object)");
+ private static final MethodMatcher REPLACE = new MethodMatcher("* replace(java.lang.Object, java.lang.Object, java.lang.Object)");
+ private static final MethodMatcher RESULT_SET_STRING_PARAM = new MethodMatcher("java.sql.ResultSet *(java.lang.String)");
+ private static final MethodMatcher RETURNED_CLASS = new MethodMatcher("* returnedClass()");
+ private static final MethodMatcher SQL_TYPES = new MethodMatcher("* sqlTypes()");
+
+ @Override
+ public String getDisplayName() {
+ return "Migrate `UserType` to Hibernate 6";
+ }
+
+ @Override
+ public String getDescription() {
+ return "With Hibernate 6 the `UserType` interface received a type parameter making it more strictly typed. " +
+ "This recipe applies the changes required to adhere to this change.";
+ }
+
+ @Override
+ public TreeVisitor, ExecutionContext> getVisitor() {
+ return Preconditions.check(Preconditions.and(
+ new FindImplementations(USER_TYPE).getVisitor(),
+ // This method only exists on the Hibernate 6 variant of UserType, so as a precondition this shouldn't exist
+ Preconditions.not(new FindMethodDeclaration("* getSqlType()", true).getVisitor())
+ ), new JavaVisitor() {
+ @Override
+ public J visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
+ J.ClassDeclaration cd = classDecl;
+ J.FieldAccess parameterizedType = getReturnedClass(cd);
+ cd = cd.withImplements(ListUtils.map(cd.getImplements(), impl -> {
+ if (TypeUtils.isAssignableTo(USER_TYPE, impl.getType()) && parameterizedType != null) {
+ return TypeTree.build("UserType<" + parameterizedType.getTarget() + ">").withType(JavaType.buildType(USER_TYPE)).withPrefix(Space.SINGLE_SPACE);
+ }
+ return impl;
+ }));
+ if (parameterizedType != null) {
+ getCursor().putMessage("parameterizedType", parameterizedType);
+ }
+ return super.visitClassDeclaration(cd, ctx);
+ }
+
+ @SuppressWarnings("ConstantConditions")
+ private J.@Nullable FieldAccess getReturnedClass(J.ClassDeclaration cd) {
+ AtomicReference reference = new AtomicReference<>();
+ new JavaIsoVisitor>() {
+ @Override
+ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, AtomicReference ref) {
+ // Only visit top level method returnedClass
+ return RETURNED_CLASS.matches(method, cd) ? super.visitMethodDeclaration(method, ref) : method;
+ }
+
+ @Override
+ public J.Return visitReturn(J.Return _return, AtomicReference ref) {
+ ref.set((J.FieldAccess) _return.getExpression());
+ return _return;
+ }
+ }.visitNonNull(cd, reference);
+ return reference.get();
+ }
+
+ @Override
+ public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
+ J.MethodDeclaration md = method;
+ J.ClassDeclaration cd = getCursor().firstEnclosing(J.ClassDeclaration.class);
+ J.FieldAccess parameterizedType = getCursor().getNearestMessage("parameterizedType");
+ if (cd == null || parameterizedType == null) {
+ return md;
+ }
+ if (SQL_TYPES.matches(md, cd)) {
+ if (md.getBody() != null) {
+ Optional ret = md.getBody().getStatements().stream().filter(J.Return.class::isInstance).map(J.Return.class::cast).findFirst();
+ if (ret.isPresent()) {
+ if (ret.get().getExpression() instanceof J.NewArray) {
+ J.NewArray newArray = (J.NewArray) ret.get().getExpression();
+ if (newArray.getInitializer() != null) {
+ String template = "@Override\n" +
+ "public int getSqlType() {\n" +
+ " return #{any()};\n" +
+ "}";
+ md = JavaTemplate.builder(template)
+ .javaParser(JavaParser.fromJavaVersion())
+ .build()
+ .apply(getCursor(), md.getCoordinates().replace(), newArray.getInitializer().get(0)).withId(md.getId());
+ }
+ }
+
+ }
+ }
+ } else if (RETURNED_CLASS.matches(md, cd)) {
+ md = md.withReturnTypeExpression(TypeTree.build("Class<" + parameterizedType.getTarget() + ">"));
+ if (md.getReturnTypeExpression() != null) {
+ md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
+ }
+ } else if (EQUALS.matches(md, cd)) {
+ md = changeParameterTypes(md, Arrays.asList(0, 1), parameterizedType);
+ } else if (HASHCODE.matches(md, cd)) {
+ md = changeParameterTypes(md, Collections.singletonList(0), parameterizedType);
+ } else if (NULL_SAFE_GET_STRING_ARRAY.matches(md, cd)) {
+ String template = "@Override\n" +
+ "public BigDecimal nullSafeGet(ResultSet rs, int position, SharedSessionContractImplementor session, Object owner) throws SQLException {\n" +
+ "}";
+ J.MethodDeclaration updatedParam = JavaTemplate.builder(template)
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "hibernate-core"))
+ .imports("java.math.BigDecimal", "java.sql.ResultSet", "java.sql.SQLException", "org.hibernate.engine.spi.SharedSessionContractImplementor")
+ .build()
+ .apply(getCursor(), md.getCoordinates().replace());
+ md = updatedParam.withId(md.getId()).withBody(md.getBody());
+ } else if (NULL_SAFE_SET.matches(md, cd)) {
+ md = changeParameterTypes(md, Collections.singletonList(1), parameterizedType);
+ } else if (DEEP_COPY.matches(md, cd)) {
+ md = md.withReturnTypeExpression(parameterizedType.getTarget().withPrefix(Space.SINGLE_SPACE));
+ if (md.getReturnTypeExpression() != null) {
+ md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
+ }
+ md = changeParameterTypes(md, Collections.singletonList(0), parameterizedType);
+ } else if (DISASSEMBLE.matches(md, cd)) {
+ md = changeParameterTypes(md, Collections.singletonList(0), parameterizedType);
+ if (md.getBody() != null) {
+ md = md.withBody(md.getBody().withStatements(ListUtils.map(md.getBody().getStatements(), stmt -> {
+ if (stmt instanceof J.Return) {
+ J.Return r = (J.Return) stmt;
+ if (r.getExpression() instanceof J.TypeCast) {
+ J.TypeCast tc = (J.TypeCast) r.getExpression();
+ if (TypeUtils.isOfType(parameterizedType.getTarget().getType(), tc.getClazz().getType())) {
+ return r.withExpression(tc.getExpression());
+ }
+ }
+ }
+ return stmt;
+ })));
+ }
+ } else if (ASSEMBLE.matches(md, cd)) {
+ md = md.withReturnTypeExpression(parameterizedType.getTarget().withPrefix(Space.SINGLE_SPACE));
+ if (md.getReturnTypeExpression() != null) {
+ md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
+ }
+ if (md.getBody() != null) {
+ md = md.withBody(md.getBody().withStatements(ListUtils.map(md.getBody().getStatements(), stmt -> {
+ if (stmt instanceof J.Return) {
+ J.Return r = (J.Return) stmt;
+ if (r.getExpression() != null && !TypeUtils.isOfType(parameterizedType.getTarget().getType(), r.getExpression().getType())) {
+ return r.withExpression(new J.TypeCast(randomId(), Space.EMPTY, Markers.EMPTY, new J.ControlParentheses<>(randomId(), Space.EMPTY, Markers.EMPTY,
+ new JRightPadded<>(TypeTree.build("BigDecimal").withType(parameterizedType.getTarget().getType()), Space.EMPTY, Markers.EMPTY)), r.getExpression()));
+ }
+ }
+ return stmt;
+ })));
+ }
+ } else if (REPLACE.matches(md, cd)) {
+ md = md.withReturnTypeExpression(parameterizedType.getTarget().withPrefix(Space.SINGLE_SPACE));
+ if (md.getReturnTypeExpression() != null) {
+ md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
+ }
+ md = changeParameterTypes(md, Arrays.asList(0, 1), parameterizedType);
+ }
+ updateCursor(md);
+ md = (J.MethodDeclaration) super.visitMethodDeclaration(md, ctx);
+ return maybeAutoFormat(method, md, ctx);
+ }
+
+ private J.MethodDeclaration changeParameterTypes(J.MethodDeclaration md, List paramIndexes, J.FieldAccess parameterizedType) {
+ if (md.getMethodType() != null) {
+ JavaType.Method met = md.getMethodType().withParameterTypes(ListUtils.map(md.getMethodType().getParameterTypes(),
+ (index, type) -> {
+ if (paramIndexes.contains(index)) {
+ type = TypeUtils.isOfType(JavaType.buildType("java.lang.Object"), type) ? parameterizedType.getTarget().getType() : type;
+ }
+ return type;
+ }));
+ return md.withParameters(ListUtils.map(md.getParameters(), (index, param) -> {
+ if (param instanceof J.VariableDeclarations && paramIndexes.contains(index)) {
+ param = ((J.VariableDeclarations) param)
+ .withType(parameterizedType.getTarget().getType()).withTypeExpression((TypeTree) parameterizedType.getTarget())
+ .withVariables(ListUtils.map(((J.VariableDeclarations) param).getVariables(), var -> {
+ var = var.withType(parameterizedType.getTarget().getType());
+ if (var.getVariableType() != null && parameterizedType.getTarget().getType() != null) {
+ var = var.withVariableType(var.getVariableType().withType(parameterizedType.getTarget().getType()).withOwner(met));
+ }
+ return var;
+ }));
+ }
+ return param;
+ }));
+ }
+ return md;
+ }
+
+ @Override
+ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
+ J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);
+ if (RESULT_SET_STRING_PARAM.matches(mi)) {
+ J.MethodDeclaration md = getCursor().firstEnclosing(J.MethodDeclaration.class);
+ J.ClassDeclaration cd = getCursor().firstEnclosing(J.ClassDeclaration.class);
+ if (md != null && cd != null && NULL_SAFE_GET_INT.matches(md, cd)) {
+ mi = mi.withArguments(Collections.singletonList(((J.VariableDeclarations) md.getParameters().get(1)).getVariables().get(0).getName()));
+ if (mi.getMethodType() != null) {
+ mi = mi.withMethodType(mi.getMethodType().withParameterTypes(Collections.singletonList(JavaType.buildType("int"))));
+ }
+ }
+ }
+ return mi;
+ }
+ });
+ }
+}
diff --git a/src/test/java/org/openrewrite/hibernate/MigrateBooleanMappingsTest.java b/src/test/java/org/openrewrite/hibernate/MigrateBooleanMappingsTest.java
index ee61796..c0d3d27 100644
--- a/src/test/java/org/openrewrite/hibernate/MigrateBooleanMappingsTest.java
+++ b/src/test/java/org/openrewrite/hibernate/MigrateBooleanMappingsTest.java
@@ -32,7 +32,7 @@ class MigrateBooleanMappingsTest implements RewriteTest {
public void defaults(RecipeSpec spec) {
spec.recipe(new MigrateBooleanMappings())
.parser(JavaParser.fromJavaVersion()
- .classpathFromResources(new InMemoryExecutionContext(), "hibernate-core", "jakarta.persistence-api")
+ .classpathFromResources(new InMemoryExecutionContext(), "hibernate-core-6+", "jakarta.persistence-api")
);
}
diff --git a/src/test/java/org/openrewrite/hibernate/MigrateUserTypeTest.java b/src/test/java/org/openrewrite/hibernate/MigrateUserTypeTest.java
new file mode 100644
index 0000000..b86049f
--- /dev/null
+++ b/src/test/java/org/openrewrite/hibernate/MigrateUserTypeTest.java
@@ -0,0 +1,198 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.hibernate;
+
+import org.junit.jupiter.api.Test;
+import org.openrewrite.DocumentExample;
+import org.openrewrite.InMemoryExecutionContext;
+import org.openrewrite.java.JavaParser;
+import org.openrewrite.test.RecipeSpec;
+import org.openrewrite.test.RewriteTest;
+
+import static org.openrewrite.java.Assertions.java;
+
+class MigrateUserTypeTest implements RewriteTest {
+
+ @Override
+ public void defaults(RecipeSpec spec) {
+ spec.recipe(new MigrateUserType())
+ .parser(JavaParser.fromJavaVersion()
+ .classpathFromResources(new InMemoryExecutionContext(), "hibernate-core-6")
+ );
+ }
+
+ @Test
+ @DocumentExample
+ void shouldMigrateUserType() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ import org.hibernate.HibernateException;
+ import org.hibernate.engine.spi.SharedSessionContractImplementor;
+ import org.hibernate.usertype.UserType;
+
+ import java.io.Serializable;
+ import java.math.BigDecimal;
+ import java.sql.PreparedStatement;
+ import java.sql.ResultSet;
+ import java.sql.SQLException;
+ import java.sql.Types;
+ import java.util.Objects;
+
+ public class BigDecimalAsString implements UserType {
+
+ @Override
+ public int[] sqlTypes() {
+ return new int[]{Types.VARCHAR};
+ }
+
+ @Override
+ public Class returnedClass() {
+ return BigDecimal.class;
+ }
+
+ @Override
+ public boolean equals(Object x, Object y) {
+ return Objects.equals(x, y);
+ }
+
+ @Override
+ public int hashCode(Object x) {
+ return Objects.hashCode(x);
+ }
+
+ @Override
+ public Object nullSafeGet(ResultSet rs, String[] names, SharedSessionContractImplementor session, Object owner) throws SQLException {
+ String string = rs.getString(names[0]);
+ return string == null || rs.wasNull() ? null : new BigDecimal(string);
+ }
+
+ @Override
+ public void nullSafeSet(PreparedStatement st, Object value, int index, SharedSessionContractImplementor session) throws SQLException {
+ if (value == null) {
+ st.setNull(index, Types.VARCHAR);
+ } else {
+ st.setString(index, value.toString());
+ }
+ }
+
+ @Override
+ public Object deepCopy(Object value1) {
+ return value1;
+ }
+
+ @Override
+ public boolean isMutable() {
+ return false;
+ }
+
+ @Override
+ public Serializable disassemble(Object value) {
+ return (BigDecimal) value;
+ }
+
+ @Override
+ public Object assemble(Serializable cached, Object owner) {
+ return cached;
+ }
+
+ @Override
+ public Object replace(Object original, Object target, Object owner) {
+ return original;
+ }
+ }
+ """,
+ """
+ import org.hibernate.HibernateException;
+ import org.hibernate.engine.spi.SharedSessionContractImplementor;
+ import org.hibernate.usertype.UserType;
+
+ import java.io.Serializable;
+ import java.math.BigDecimal;
+ import java.sql.PreparedStatement;
+ import java.sql.ResultSet;
+ import java.sql.SQLException;
+ import java.sql.Types;
+ import java.util.Objects;
+
+ public class BigDecimalAsString implements UserType {
+
+ @Override
+ public int getSqlType() {
+ return Types.VARCHAR;
+ }
+
+ @Override
+ public Class returnedClass() {
+ return BigDecimal.class;
+ }
+
+ @Override
+ public boolean equals(BigDecimal x, BigDecimal y) {
+ return Objects.equals(x, y);
+ }
+
+ @Override
+ public int hashCode(BigDecimal x) {
+ return Objects.hashCode(x);
+ }
+
+ @Override
+ public BigDecimal nullSafeGet(ResultSet rs, int position, SharedSessionContractImplementor session, Object owner) throws SQLException {
+ String string = rs.getString(position);
+ return string == null || rs.wasNull() ? null : new BigDecimal(string);
+ }
+
+ @Override
+ public void nullSafeSet(PreparedStatement st, BigDecimal value, int index, SharedSessionContractImplementor session) throws SQLException {
+ if (value == null) {
+ st.setNull(index, Types.VARCHAR);
+ } else {
+ st.setString(index, value.toString());
+ }
+ }
+
+ @Override
+ public BigDecimal deepCopy(BigDecimal value1) {
+ return value1;
+ }
+
+ @Override
+ public boolean isMutable() {
+ return false;
+ }
+
+ @Override
+ public Serializable disassemble(BigDecimal value) {
+ return value;
+ }
+
+ @Override
+ public BigDecimal assemble(Serializable cached, Object owner) {
+ return (BigDecimal) cached;
+ }
+
+ @Override
+ public BigDecimal replace(BigDecimal original, BigDecimal target, Object owner) {
+ return original;
+ }
+ }
+ """
+ )
+ );
+ }
+}
diff --git a/src/test/java/org/openrewrite/hibernate/ReplaceLazyCollectionAnnotationTest.java b/src/test/java/org/openrewrite/hibernate/ReplaceLazyCollectionAnnotationTest.java
index db77674..86f75f4 100644
--- a/src/test/java/org/openrewrite/hibernate/ReplaceLazyCollectionAnnotationTest.java
+++ b/src/test/java/org/openrewrite/hibernate/ReplaceLazyCollectionAnnotationTest.java
@@ -31,7 +31,7 @@ class ReplaceLazyCollectionAnnotationTest implements RewriteTest {
public void defaults(RecipeSpec spec) {
spec.recipe(new ReplaceLazyCollectionAnnotation())
.parser(JavaParser.fromJavaVersion()
- .classpathFromResources(new InMemoryExecutionContext(), "hibernate-core", "jakarta.persistence-api")
+ .classpathFromResources(new InMemoryExecutionContext(), "hibernate-core-6+", "jakarta.persistence-api")
);
}
diff --git a/src/test/java/org/openrewrite/hibernate/TypeAnnotationParameterTest.java b/src/test/java/org/openrewrite/hibernate/TypeAnnotationParameterTest.java
index 65c8710..f468142 100644
--- a/src/test/java/org/openrewrite/hibernate/TypeAnnotationParameterTest.java
+++ b/src/test/java/org/openrewrite/hibernate/TypeAnnotationParameterTest.java
@@ -27,7 +27,7 @@
class TypeAnnotationParameterTest implements RewriteTest {
@Override
public void defaults(RecipeSpec spec) {
- spec.recipe(new TypeAnnotationParameter()).parser(JavaParser.fromJavaVersion().classpathFromResources(new InMemoryExecutionContext(), "hibernate-core"));
+ spec.recipe(new TypeAnnotationParameter()).parser(JavaParser.fromJavaVersion().classpathFromResources(new InMemoryExecutionContext(), "hibernate-core-6+"));
}
@DocumentExample