From c13c3527feedfd179e2743c01e7381078d3ea404 Mon Sep 17 00:00:00 2001 From: HappenLee Date: Wed, 5 Feb 2025 15:34:16 +0800 Subject: [PATCH] [Refactor](UDF) Refactor the java udf code to reduce the unless code --- .../common/jni/utils/JavaUdfDataType.java | 97 +++++++++++-------- .../doris/common/jni/utils/UdfClassCache.java | 2 + .../org/apache/doris/udf/BaseExecutor.java | 50 +++++----- .../org/apache/doris/udf/UdafExecutor.java | 23 +---- .../org/apache/doris/udf/UdfExecutor.java | 33 ++----- 5 files changed, 92 insertions(+), 113 deletions(-) diff --git a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/JavaUdfDataType.java b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/JavaUdfDataType.java index 6077f713e8319d9..febbf30c691e7f1 100644 --- a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/JavaUdfDataType.java +++ b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/JavaUdfDataType.java @@ -29,7 +29,8 @@ import java.math.BigInteger; import java.net.InetAddress; import java.util.ArrayList; -import java.util.HashSet; +import java.util.HashMap; +import java.util.Map; import java.util.Set; // Data types that are supported as return or argument types in Java UDFs. @@ -63,32 +64,36 @@ public class JavaUdfDataType { public static final JavaUdfDataType MAP_TYPE = new JavaUdfDataType("MAP_TYPE", TPrimitiveType.MAP, 0); public static final JavaUdfDataType STRUCT_TYPE = new JavaUdfDataType("STRUCT_TYPE", TPrimitiveType.STRUCT, 0); - private static Set JavaUdfDataTypeSet = new HashSet<>(); + private static final Map javaUdfDataTypeMap = new HashMap<>(); + + public static void addJavaUdfDataType(JavaUdfDataType dataType) { + javaUdfDataTypeMap.put(dataType.getPrimitiveType(), dataType); + } static { - JavaUdfDataTypeSet.add(INVALID_TYPE); - JavaUdfDataTypeSet.add(BOOLEAN); - JavaUdfDataTypeSet.add(TINYINT); - JavaUdfDataTypeSet.add(SMALLINT); - JavaUdfDataTypeSet.add(INT); - JavaUdfDataTypeSet.add(BIGINT); - JavaUdfDataTypeSet.add(FLOAT); - JavaUdfDataTypeSet.add(DOUBLE); - JavaUdfDataTypeSet.add(STRING); - JavaUdfDataTypeSet.add(DATE); - JavaUdfDataTypeSet.add(DATETIME); - JavaUdfDataTypeSet.add(LARGEINT); - JavaUdfDataTypeSet.add(DECIMALV2); - JavaUdfDataTypeSet.add(DATEV2); - JavaUdfDataTypeSet.add(DATETIMEV2); - JavaUdfDataTypeSet.add(DECIMAL32); - JavaUdfDataTypeSet.add(DECIMAL64); - JavaUdfDataTypeSet.add(DECIMAL128); - JavaUdfDataTypeSet.add(ARRAY_TYPE); - JavaUdfDataTypeSet.add(MAP_TYPE); - JavaUdfDataTypeSet.add(STRUCT_TYPE); - JavaUdfDataTypeSet.add(IPV4); - JavaUdfDataTypeSet.add(IPV6); + addJavaUdfDataType(INVALID_TYPE); + addJavaUdfDataType(BOOLEAN); + addJavaUdfDataType(TINYINT); + addJavaUdfDataType(SMALLINT); + addJavaUdfDataType(INT); + addJavaUdfDataType(BIGINT); + addJavaUdfDataType(FLOAT); + addJavaUdfDataType(DOUBLE); + addJavaUdfDataType(STRING); + addJavaUdfDataType(DATE); + addJavaUdfDataType(DATETIME); + addJavaUdfDataType(LARGEINT); + addJavaUdfDataType(DECIMALV2); + addJavaUdfDataType(DATEV2); + addJavaUdfDataType(DATETIMEV2); + addJavaUdfDataType(DECIMAL32); + addJavaUdfDataType(DECIMAL64); + addJavaUdfDataType(DECIMAL128); + addJavaUdfDataType(ARRAY_TYPE); + addJavaUdfDataType(MAP_TYPE); + addJavaUdfDataType(STRUCT_TYPE); + addJavaUdfDataType(IPV4); + addJavaUdfDataType(IPV6); } private final String description; @@ -117,17 +122,33 @@ public JavaUdfDataType(JavaUdfDataType other) { @Override public String toString() { - return description; - } + StringBuilder res = new StringBuilder(); + res.append(description); + // TODO: the item/key/value type should be dispose in child class + if (getItemType() != null) { + res.append(" item: ").append(getItemType().toString()).append(" sql: ") + .append(getItemType().toSql()); + } + if (getKeyType() != null) { + res.append(" key: ").append(getKeyType().toString()).append(" sql: ") + .append(getKeyType().toSql()); + } + if (getValueType() != null) { + res.append(" value: ").append(getValueType().toString()).append(" sql: ") + .append(getValueType().toSql()); + } - public TPrimitiveType getPrimitiveType() { - return thriftType; + return res.toString(); } public int getLen() { return len; } + public TPrimitiveType getPrimitiveType() { + return thriftType; + } + public static Set getCandidateTypes(Class c) { if (c == boolean.class || c == Boolean.class) { return Sets.newHashSet(JavaUdfDataType.BOOLEAN); @@ -169,19 +190,14 @@ public static Set getCandidateTypes(Class c) { } public static boolean isSupported(Type t) { - for (JavaUdfDataType javaType : JavaUdfDataTypeSet) { - if (javaType == JavaUdfDataType.INVALID_TYPE) { - continue; - } - if (javaType.getPrimitiveType() == t.getPrimitiveType().toThrift()) { - return true; - } - } - if (t.getPrimitiveType().toThrift() == TPrimitiveType.VARCHAR - || t.getPrimitiveType().toThrift() == TPrimitiveType.CHAR) { + TPrimitiveType thriftType = t.getPrimitiveType().toThrift(); + // varchar and char are supported in java udf, type is String + if (thriftType == TPrimitiveType.VARCHAR + || thriftType == TPrimitiveType.CHAR) { return true; } - return false; + return !thriftType.equals(TPrimitiveType.INVALID_TYPE) + && javaUdfDataTypeMap.containsKey(thriftType); } public int getPrecision() { @@ -209,7 +225,6 @@ public void setItemType(Type type) throws InternalException { this.itemType = type; } else { if (!this.itemType.matchesType(type)) { - LOG.info("set error"); throw new InternalException("udf type not matches origin type :" + this.itemType.toSql() + " set type :" + type.toSql()); } diff --git a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/UdfClassCache.java b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/UdfClassCache.java index 4515e7052141448..696ef4ed0bb182b 100644 --- a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/UdfClassCache.java +++ b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/UdfClassCache.java @@ -38,4 +38,6 @@ public class UdfClassCache { public JavaUdfDataType retType; // the class type of the arguments in evaluate() method public Class[] argClass; + // The return type class of evaluate() method + public Class retClass; } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index bef25f83da87fa2..3393e74234ce962 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -23,6 +23,7 @@ import org.apache.doris.common.exception.UdfRuntimeException; import org.apache.doris.common.jni.utils.JavaUdfDataType; import org.apache.doris.common.jni.vec.ColumnValueConverter; +import org.apache.doris.common.jni.vec.VectorTable; import org.apache.doris.thrift.TFunction; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; import org.apache.doris.thrift.TPrimitiveType; @@ -39,6 +40,7 @@ import java.time.LocalDateTime; import java.util.ArrayList; import java.util.HashMap; +import java.util.Map; import java.util.Map.Entry; public abstract class BaseExecutor { @@ -69,7 +71,9 @@ public abstract class BaseExecutor { protected JavaUdfDataType retType; protected Class[] argClass; protected MethodAccess methodAccess; + protected VectorTable outputTable = null; protected TFunction fn; + protected Class retClass; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used @@ -102,32 +106,8 @@ public String debugString() { StringBuilder res = new StringBuilder(); for (JavaUdfDataType type : argTypes) { res.append(type.toString()); - if (type.getItemType() != null) { - res.append(" item: ").append(type.getItemType().toString()).append(" sql: ") - .append(type.getItemType().toSql()); - } - if (type.getKeyType() != null) { - res.append(" key: ").append(type.getKeyType().toString()).append(" sql: ") - .append(type.getKeyType().toSql()); - } - if (type.getValueType() != null) { - res.append(" key: ").append(type.getValueType().toString()).append(" sql: ") - .append(type.getValueType().toSql()); - } } res.append(" return type: ").append(retType.toString()); - if (retType.getItemType() != null) { - res.append(" item: ").append(retType.getItemType().toString()).append(" sql: ") - .append(retType.getItemType().toSql()); - } - if (retType.getKeyType() != null) { - res.append(" key: ").append(retType.getKeyType().toString()).append(" sql: ") - .append(retType.getKeyType().toSql()); - } - if (retType.getValueType() != null) { - res.append(" key: ").append(retType.getValueType().toString()).append(" sql: ") - .append(retType.getValueType().toSql()); - } res.append(" methodAccess: ").append(methodAccess.toString()); res.append(" fn.toString(): ").append(fn.toString()); return res.toString(); @@ -150,6 +130,10 @@ public void close() { } } } + // Close the output table if it exists. + if (outputTable != null) { + outputTable.close(); + } // We are now un-usable (because the class loader has been // closed), so null out method_ and classLoader_. classLoader = null; @@ -330,4 +314,22 @@ protected ColumnValueConverter getOutputConverter(JavaUdfDataType returnType, Cl } return null; } + + // Add unified converter methods + protected Map getInputConverters(int numColumns, boolean isUdaf) { + Map converters = new HashMap<>(); + for (int j = 0; j < numColumns; ++j) { + // For UDAF, we need to offset by 1 since first arg is state + int argIndex = isUdaf ? j + 1 : j; + ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[argIndex]); + if (converter != null) { + converters.put(j, converter); + } + } + return converters; + } + + protected ColumnValueConverter getOutputConverter() { + return getOutputConverter(retType, retClass); + } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index 31fb93ed8b6f3b4..629a67ba4a90dd7 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -53,10 +53,7 @@ public class UdafExecutor extends BaseExecutor { private HashMap allMethods; private HashMap stateObjMap; - private Class retClass; private int addIndex; - private VectorTable outputTable = null; - /** * Constructor to create an object. */ @@ -69,35 +66,17 @@ public UdafExecutor(byte[] thriftParams) throws Exception { */ @Override public void close() { - if (outputTable != null) { - outputTable.close(); - } super.close(); allMethods = null; stateObjMap = null; } - private Map getInputConverters(int numColumns) { - Map converters = new HashMap<>(); - for (int j = 0; j < numColumns; ++j) { - ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j + 1]); - if (converter != null) { - converters.put(j, converter); - } - } - return converters; - } - - private ColumnValueConverter getOutputConverter() { - return getOutputConverter(retType, retClass); - } - public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset, Map inputParams) throws UdfRuntimeException { try { VectorTable inputTable = VectorTable.createReadableTable(inputParams); Object[][] inputs = inputTable.getMaterializedData(rowStart, rowEnd, - getInputConverters(inputTable.getNumColumns())); + getInputConverters(inputTable.getNumColumns(), true)); if (isSinglePlace) { addBatchSingle(rowStart, rowEnd, placeAddr, inputs); } else { diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 1b5bff1e7c11d1e..ca5dea62577f8e8 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -53,8 +53,6 @@ public class UdfExecutor extends BaseExecutor { private int evaluateIndex; - private VectorTable outputTable = null; - private boolean isStaticLoad = false; /** @@ -70,33 +68,16 @@ public UdfExecutor(byte[] thriftParams) throws Exception { */ @Override public void close() { - // inputTable is released by c++, only release outputTable - if (outputTable != null) { - outputTable.close(); - } // We are now un-usable (because the class loader has been // closed), so null out method_ and classLoader_. method = null; if (!isStaticLoad) { super.close(); + } else if (outputTable != null) { + outputTable.close(); } } - private Map getInputConverters(int numColumns) { - Map converters = new HashMap<>(); - for (int j = 0; j < numColumns; ++j) { - ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j]); - if (converter != null) { - converters.put(j, converter); - } - } - return converters; - } - - private ColumnValueConverter getOutputConverter() { - return getOutputConverter(retType, method.getReturnType()); - } - public long evaluate(Map inputParams, Map outputParams) throws UdfRuntimeException { try { VectorTable inputTable = VectorTable.createReadableTable(inputParams); @@ -112,7 +93,7 @@ public long evaluate(Map inputParams, Map output Object[] result = outputTable.getColumnType(0).isPrimitive() ? outputTable.getColumn(0).newObjectContainerArray(numRows) : (Object[]) Array.newInstance(method.getReturnType(), numRows); - Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns)); + Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns, false)); Object[] parameters = new Object[numColumns]; for (int i = 0; i < numRows; ++i) { for (int j = 0; j < numColumns; ++j) { @@ -216,16 +197,15 @@ private void checkAndCacheUdfClass(String className, UdfClassCache cache, Type f } else { cache.retType = returnType.second; } - Type keyType = cache.retType.getKeyType(); - Type valueType = cache.retType.getValueType(); Pair inputType = UdfUtils.setArgTypes(parameterTypes, cache.argClass, false); if (!inputType.first) { continue; } else { cache.argTypes = inputType.second; } - cache.retType.setKeyType(keyType); - cache.retType.setValueType(valueType); + if (cache.method != null) { + cache.retClass = cache.method.getReturnType(); + } return; } StringBuilder sb = new StringBuilder(); @@ -269,6 +249,7 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun evaluateIndex = cache.evaluateIndex; retType = cache.retType; argTypes = cache.argTypes; + retClass = cache.retClass; } catch (MalformedURLException e) { throw new UdfRuntimeException("Unable to load jar.", e); } catch (SecurityException e) {