Skip to content

Commit

Permalink
[Refactor](UDF) Refactor the java udf code to reduce the unless code
Browse files Browse the repository at this point in the history
  • Loading branch information
HappenLee committed Feb 5, 2025
1 parent e042e9f commit c13c352
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
import java.math.BigInteger;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Data types that are supported as return or argument types in Java UDFs.
Expand Down Expand Up @@ -63,32 +64,36 @@ public class JavaUdfDataType {
public static final JavaUdfDataType MAP_TYPE = new JavaUdfDataType("MAP_TYPE", TPrimitiveType.MAP, 0);
public static final JavaUdfDataType STRUCT_TYPE = new JavaUdfDataType("STRUCT_TYPE", TPrimitiveType.STRUCT, 0);

private static Set<JavaUdfDataType> JavaUdfDataTypeSet = new HashSet<>();
private static final Map<TPrimitiveType, JavaUdfDataType> javaUdfDataTypeMap = new HashMap<>();

public static void addJavaUdfDataType(JavaUdfDataType dataType) {
javaUdfDataTypeMap.put(dataType.getPrimitiveType(), dataType);
}

static {
JavaUdfDataTypeSet.add(INVALID_TYPE);
JavaUdfDataTypeSet.add(BOOLEAN);
JavaUdfDataTypeSet.add(TINYINT);
JavaUdfDataTypeSet.add(SMALLINT);
JavaUdfDataTypeSet.add(INT);
JavaUdfDataTypeSet.add(BIGINT);
JavaUdfDataTypeSet.add(FLOAT);
JavaUdfDataTypeSet.add(DOUBLE);
JavaUdfDataTypeSet.add(STRING);
JavaUdfDataTypeSet.add(DATE);
JavaUdfDataTypeSet.add(DATETIME);
JavaUdfDataTypeSet.add(LARGEINT);
JavaUdfDataTypeSet.add(DECIMALV2);
JavaUdfDataTypeSet.add(DATEV2);
JavaUdfDataTypeSet.add(DATETIMEV2);
JavaUdfDataTypeSet.add(DECIMAL32);
JavaUdfDataTypeSet.add(DECIMAL64);
JavaUdfDataTypeSet.add(DECIMAL128);
JavaUdfDataTypeSet.add(ARRAY_TYPE);
JavaUdfDataTypeSet.add(MAP_TYPE);
JavaUdfDataTypeSet.add(STRUCT_TYPE);
JavaUdfDataTypeSet.add(IPV4);
JavaUdfDataTypeSet.add(IPV6);
addJavaUdfDataType(INVALID_TYPE);
addJavaUdfDataType(BOOLEAN);
addJavaUdfDataType(TINYINT);
addJavaUdfDataType(SMALLINT);
addJavaUdfDataType(INT);
addJavaUdfDataType(BIGINT);
addJavaUdfDataType(FLOAT);
addJavaUdfDataType(DOUBLE);
addJavaUdfDataType(STRING);
addJavaUdfDataType(DATE);
addJavaUdfDataType(DATETIME);
addJavaUdfDataType(LARGEINT);
addJavaUdfDataType(DECIMALV2);
addJavaUdfDataType(DATEV2);
addJavaUdfDataType(DATETIMEV2);
addJavaUdfDataType(DECIMAL32);
addJavaUdfDataType(DECIMAL64);
addJavaUdfDataType(DECIMAL128);
addJavaUdfDataType(ARRAY_TYPE);
addJavaUdfDataType(MAP_TYPE);
addJavaUdfDataType(STRUCT_TYPE);
addJavaUdfDataType(IPV4);
addJavaUdfDataType(IPV6);
}

private final String description;
Expand Down Expand Up @@ -117,17 +122,33 @@ public JavaUdfDataType(JavaUdfDataType other) {

@Override
public String toString() {
return description;
}
StringBuilder res = new StringBuilder();
res.append(description);
// TODO: the item/key/value type should be dispose in child class
if (getItemType() != null) {
res.append(" item: ").append(getItemType().toString()).append(" sql: ")
.append(getItemType().toSql());
}
if (getKeyType() != null) {
res.append(" key: ").append(getKeyType().toString()).append(" sql: ")
.append(getKeyType().toSql());
}
if (getValueType() != null) {
res.append(" value: ").append(getValueType().toString()).append(" sql: ")
.append(getValueType().toSql());
}

public TPrimitiveType getPrimitiveType() {
return thriftType;
return res.toString();
}

public int getLen() {
return len;
}

public TPrimitiveType getPrimitiveType() {
return thriftType;
}

public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
if (c == boolean.class || c == Boolean.class) {
return Sets.newHashSet(JavaUdfDataType.BOOLEAN);
Expand Down Expand Up @@ -169,19 +190,14 @@ public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
}

public static boolean isSupported(Type t) {
for (JavaUdfDataType javaType : JavaUdfDataTypeSet) {
if (javaType == JavaUdfDataType.INVALID_TYPE) {
continue;
}
if (javaType.getPrimitiveType() == t.getPrimitiveType().toThrift()) {
return true;
}
}
if (t.getPrimitiveType().toThrift() == TPrimitiveType.VARCHAR
|| t.getPrimitiveType().toThrift() == TPrimitiveType.CHAR) {
TPrimitiveType thriftType = t.getPrimitiveType().toThrift();
// varchar and char are supported in java udf, type is String
if (thriftType == TPrimitiveType.VARCHAR
|| thriftType == TPrimitiveType.CHAR) {
return true;
}
return false;
return !thriftType.equals(TPrimitiveType.INVALID_TYPE)
&& javaUdfDataTypeMap.containsKey(thriftType);
}

public int getPrecision() {
Expand Down Expand Up @@ -209,7 +225,6 @@ public void setItemType(Type type) throws InternalException {
this.itemType = type;
} else {
if (!this.itemType.matchesType(type)) {
LOG.info("set error");
throw new InternalException("udf type not matches origin type :" + this.itemType.toSql()
+ " set type :" + type.toSql());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ public class UdfClassCache {
public JavaUdfDataType retType;
// the class type of the arguments in evaluate() method
public Class[] argClass;
// The return type class of evaluate() method
public Class retClass;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TFunction;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;
Expand All @@ -39,6 +40,7 @@
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

public abstract class BaseExecutor {
Expand Down Expand Up @@ -69,7 +71,9 @@ public abstract class BaseExecutor {
protected JavaUdfDataType retType;
protected Class[] argClass;
protected MethodAccess methodAccess;
protected VectorTable outputTable = null;
protected TFunction fn;
protected Class retClass;

/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used
Expand Down Expand Up @@ -102,32 +106,8 @@ public String debugString() {
StringBuilder res = new StringBuilder();
for (JavaUdfDataType type : argTypes) {
res.append(type.toString());
if (type.getItemType() != null) {
res.append(" item: ").append(type.getItemType().toString()).append(" sql: ")
.append(type.getItemType().toSql());
}
if (type.getKeyType() != null) {
res.append(" key: ").append(type.getKeyType().toString()).append(" sql: ")
.append(type.getKeyType().toSql());
}
if (type.getValueType() != null) {
res.append(" key: ").append(type.getValueType().toString()).append(" sql: ")
.append(type.getValueType().toSql());
}
}
res.append(" return type: ").append(retType.toString());
if (retType.getItemType() != null) {
res.append(" item: ").append(retType.getItemType().toString()).append(" sql: ")
.append(retType.getItemType().toSql());
}
if (retType.getKeyType() != null) {
res.append(" key: ").append(retType.getKeyType().toString()).append(" sql: ")
.append(retType.getKeyType().toSql());
}
if (retType.getValueType() != null) {
res.append(" key: ").append(retType.getValueType().toString()).append(" sql: ")
.append(retType.getValueType().toSql());
}
res.append(" methodAccess: ").append(methodAccess.toString());
res.append(" fn.toString(): ").append(fn.toString());
return res.toString();
Expand All @@ -150,6 +130,10 @@ public void close() {
}
}
}
// Close the output table if it exists.
if (outputTable != null) {
outputTable.close();
}
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
classLoader = null;
Expand Down Expand Up @@ -330,4 +314,22 @@ protected ColumnValueConverter getOutputConverter(JavaUdfDataType returnType, Cl
}
return null;
}

// Add unified converter methods
protected Map<Integer, ColumnValueConverter> getInputConverters(int numColumns, boolean isUdaf) {
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
for (int j = 0; j < numColumns; ++j) {
// For UDAF, we need to offset by 1 since first arg is state
int argIndex = isUdaf ? j + 1 : j;
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[argIndex]);
if (converter != null) {
converters.put(j, converter);
}
}
return converters;
}

protected ColumnValueConverter getOutputConverter() {
return getOutputConverter(retType, retClass);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ public class UdafExecutor extends BaseExecutor {

private HashMap<String, Method> allMethods;
private HashMap<Long, Object> stateObjMap;
private Class retClass;
private int addIndex;
private VectorTable outputTable = null;

/**
* Constructor to create an object.
*/
Expand All @@ -69,35 +66,17 @@ public UdafExecutor(byte[] thriftParams) throws Exception {
*/
@Override
public void close() {
if (outputTable != null) {
outputTable.close();
}
super.close();
allMethods = null;
stateObjMap = null;
}

private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
for (int j = 0; j < numColumns; ++j) {
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j + 1]);
if (converter != null) {
converters.put(j, converter);
}
}
return converters;
}

private ColumnValueConverter getOutputConverter() {
return getOutputConverter(retType, retClass);
}

public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset,
Map<String, String> inputParams) throws UdfRuntimeException {
try {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
Object[][] inputs = inputTable.getMaterializedData(rowStart, rowEnd,
getInputConverters(inputTable.getNumColumns()));
getInputConverters(inputTable.getNumColumns(), true));
if (isSinglePlace) {
addBatchSingle(rowStart, rowEnd, placeAddr, inputs);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ public class UdfExecutor extends BaseExecutor {

private int evaluateIndex;

private VectorTable outputTable = null;

private boolean isStaticLoad = false;

/**
Expand All @@ -70,33 +68,16 @@ public UdfExecutor(byte[] thriftParams) throws Exception {
*/
@Override
public void close() {
// inputTable is released by c++, only release outputTable
if (outputTable != null) {
outputTable.close();
}
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
method = null;
if (!isStaticLoad) {
super.close();
} else if (outputTable != null) {
outputTable.close();
}
}

private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
for (int j = 0; j < numColumns; ++j) {
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j]);
if (converter != null) {
converters.put(j, converter);
}
}
return converters;
}

private ColumnValueConverter getOutputConverter() {
return getOutputConverter(retType, method.getReturnType());
}

public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
try {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
Expand All @@ -112,7 +93,7 @@ public long evaluate(Map<String, String> inputParams, Map<String, String> output
Object[] result = outputTable.getColumnType(0).isPrimitive()
? outputTable.getColumn(0).newObjectContainerArray(numRows)
: (Object[]) Array.newInstance(method.getReturnType(), numRows);
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns));
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns, false));
Object[] parameters = new Object[numColumns];
for (int i = 0; i < numRows; ++i) {
for (int j = 0; j < numColumns; ++j) {
Expand Down Expand Up @@ -216,16 +197,15 @@ private void checkAndCacheUdfClass(String className, UdfClassCache cache, Type f
} else {
cache.retType = returnType.second;
}
Type keyType = cache.retType.getKeyType();
Type valueType = cache.retType.getValueType();
Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, cache.argClass, false);
if (!inputType.first) {
continue;
} else {
cache.argTypes = inputType.second;
}
cache.retType.setKeyType(keyType);
cache.retType.setValueType(valueType);
if (cache.method != null) {
cache.retClass = cache.method.getReturnType();
}
return;
}
StringBuilder sb = new StringBuilder();
Expand Down Expand Up @@ -269,6 +249,7 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun
evaluateIndex = cache.evaluateIndex;
retType = cache.retType;
argTypes = cache.argTypes;
retClass = cache.retClass;
} catch (MalformedURLException e) {
throw new UdfRuntimeException("Unable to load jar.", e);
} catch (SecurityException e) {
Expand Down

0 comments on commit c13c352

Please sign in to comment.