Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[improve](udaf)support class cache for java-udaf #47619

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.esotericsoftware.reflectasm.MethodAccess;

import java.lang.reflect.Method;
import java.util.HashMap;

/**
* This class is used for caching the class of UDF.
Expand All @@ -28,16 +29,17 @@ public class UdfClassCache {
public Class<?> udfClass;
// the index of evaluate() method in the class
public MethodAccess methodAccess;
public int evaluateIndex;
// the method of evaluate() in udf
public Method method;
// the method of prepare() in udf
public Method prepareMethod;
// the argument and return's JavaUdfDataType of evaluate() method.
public JavaUdfDataType[] argTypes;
public JavaUdfDataType retType;
// the class type of the arguments in evaluate() method
public Class[] argClass;
// The return type class of evaluate() method
public JavaUdfDataType retType;
public Class retClass;

// all methods in the class for java-udf/ java-udaf
public HashMap<String, Method> allMethods;
// for java-udf index is evaluate method index
// for java-udaf index is add method index
public int methodIndex;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,29 @@

import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.classloader.ScannerLoader;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.UdfClassCache;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TFunction;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;

import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Strings;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.MalformedURLException;
import java.net.URLClassLoader;
import java.time.LocalDate;
import java.time.LocalDateTime;
Expand All @@ -44,26 +51,17 @@
import java.util.Map.Entry;

public abstract class BaseExecutor {
private static final Logger LOG = Logger.getLogger(BaseExecutor.class);


// Object to deserialize ctor params from BE.
protected static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();

private static final Logger LOG = Logger.getLogger(BaseExecutor.class);
protected Object udf;
// setup by init() and cleared by close()
protected URLClassLoader classLoader;

// Return and argument types of the function inferred from the udf method
// signature.
// The JavaUdfDataType enum maps it to corresponding primitive type.
protected JavaUdfDataType[] argTypes;
protected JavaUdfDataType retType;
protected Class[] argClass;
protected MethodAccess methodAccess;
protected VectorTable outputTable = null;
protected UdfClassCache objCache;
protected TFunction fn;
protected Class retClass;
protected boolean isStaticLoad = false;
protected VectorTable outputTable = null;
String className;

/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used
Expand Down Expand Up @@ -94,17 +92,79 @@ public BaseExecutor(byte[] thriftParams) throws Exception {

public String debugString() {
StringBuilder res = new StringBuilder();
for (JavaUdfDataType type : argTypes) {
for (JavaUdfDataType type : objCache.argTypes) {
res.append(type.toString());
}
res.append(" return type: ").append(retType.toString());
res.append(" methodAccess: ").append(methodAccess.toString());
res.append(" return type: ").append(objCache.retType.toString());
res.append(" methodAccess: ").append(objCache.methodAccess.toString());
res.append(" fn.toString(): ").append(fn.toString());
return res.toString();
}

protected abstract void init(TJavaUdfExecutorCtorParams request, String jarPath,
Type funcRetType, Type... parameterTypes) throws UdfRuntimeException;
/**
 * Resolves the UDF class description for this executor and instantiates the UDF.
 *
 * <p>Reads the static-load flag and the optional expiration time from the
 * function carried by {@code request}, obtains a {@link UdfClassCache} through
 * {@link #getClassCache}, and creates {@link #udf} with the class's public
 * no-argument constructor.
 *
 * @param request ctor params whose function supplies the static-load flag,
 *                expiration time and signature
 * @param jarPath location of the UDF jar; forwarded unchanged to {@code getClassCache}
 * @param funcRetType declared return type of the function
 * @param parameterTypes declared argument types of the function
 * @throws UdfRuntimeException if the class cannot be loaded or validated, or the
 *                             UDF instance cannot be constructed; the original
 *                             failure is kept as the cause
 */
protected void init(TJavaUdfExecutorCtorParams request, String jarPath,
        Type funcRetType, Type... parameterTypes) throws UdfRuntimeException {
    try {
        TFunction function = request.getFn();
        // Must be set before getClassCache(), which consults isStaticLoad.
        isStaticLoad = function.isSetIsStaticLoad() && function.is_static_load;
        long expirationTime = function.isSetExpirationTime()
                ? function.getExpirationTime()
                : 360L; // default is 6 hours
        objCache = getClassCache(jarPath, function.getSignature(), expirationTime,
                funcRetType, parameterTypes);
        udf = objCache.udfClass.getConstructor().newInstance();
    } catch (MalformedURLException e) {
        throw new UdfRuntimeException("Unable to load jar.", e);
    } catch (SecurityException e) {
        throw new UdfRuntimeException("Unable to load function.", e);
    } catch (ClassNotFoundException e) {
        throw new UdfRuntimeException("Unable to find class.", e);
    } catch (NoSuchMethodException e) {
        throw new UdfRuntimeException(
                "Unable to find constructor with no arguments.", e);
    } catch (IllegalArgumentException e) {
        throw new UdfRuntimeException(
                "Unable to call UDF constructor with no arguments.", e);
    } catch (Exception e) {
        // Catch-all for everything else raised while loading or instantiating.
        throw new UdfRuntimeException("Unable to call create UDF instance.", e);
    }
}


/**
 * Returns the class description for the UDF named by {@code className}.
 *
 * <p>When static loading is enabled, an entry previously stored in
 * {@link ScannerLoader} under {@code signature} is returned if one exists.
 * Otherwise the class is loaded, its {@link MethodAccess} is created,
 * {@link #checkAndCacheUdfClass} fills in the remaining fields, and, for
 * static loading, the new entry is stored under {@code signature}.
 *
 * @param jarPath path of the UDF jar; when null or empty the system class loader is used
 * @param signature key used for lookup and storage in {@code ScannerLoader}
 * @param expirationTime expiration value handed to {@code ScannerLoader} when storing
 * @param funcRetType declared return type, forwarded to {@code checkAndCacheUdfClass}
 * @param parameterTypes declared argument types, forwarded to {@code checkAndCacheUdfClass}
 * @return the reused or newly built cache entry
 */
public UdfClassCache getClassCache(String jarPath, String signature, long expirationTime,
        Type funcRetType, Type... parameterTypes)
        throws MalformedURLException, FileNotFoundException, ClassNotFoundException, InternalException,
        UdfRuntimeException {
    if (isStaticLoad) {
        UdfClassCache cached = ScannerLoader.getUdfClassLoader(signature);
        if (cached != null) {
            return cached;
        }
    }

    ClassLoader loader;
    if (!Strings.isNullOrEmpty(jarPath)) {
        // Remember the jar-specific loader in the field so the executor keeps a handle on it.
        classLoader = UdfUtils.getClassLoader(jarPath, getClass().getClassLoader());
        loader = classLoader;
    } else {
        // An empty jarPath means the UDF jar is located in custom_lib and was
        // already loaded when BE started, so the system class loader is used.
        loader = ClassLoader.getSystemClassLoader();
    }

    UdfClassCache fresh = new UdfClassCache();
    fresh.allMethods = new HashMap<>();
    fresh.udfClass = Class.forName(className, true, loader);
    fresh.methodAccess = MethodAccess.get(fresh.udfClass);
    checkAndCacheUdfClass(fresh, funcRetType, parameterTypes);
    if (isStaticLoad) {
        ScannerLoader.cacheClassLoader(signature, fresh, expirationTime);
    }
    return fresh;
}

/**
 * Subclass hook invoked by {@code getClassCache} on a newly created cache entry.
 *
 * <p>At the time of the call, {@code cache.udfClass}, {@code cache.methodAccess}
 * and an empty {@code cache.allMethods} map are already set; the call happens
 * before the entry is (optionally) stored for static loading.
 * NOTE(review): implementations presumably validate the UDF's methods against
 * the declared types and fill the remaining {@code UdfClassCache} fields
 * (argument/return types, method index) - confirm against the subclasses.
 *
 * @param cache the entry to validate and populate
 * @param funcRetType declared return type of the function
 * @param parameterTypes declared argument types of the function
 * @throws InternalException declared for implementations; conditions are subclass-specific
 * @throws UdfRuntimeException declared for implementations; conditions are subclass-specific
 */
protected abstract void checkAndCacheUdfClass(UdfClassCache cache, Type funcRetType, Type... parameterTypes)
throws InternalException, UdfRuntimeException;

/**
* Close the class loader we may have created.
Expand All @@ -127,7 +187,7 @@ public void close() {
// We are now unusable (because the class loader has been
// closed), so null out the class loader and the method access.
classLoader = null;
methodAccess = null;
objCache.methodAccess = null;
}

protected ColumnValueConverter getInputConverter(TPrimitiveType primitiveType, Class clz) {
Expand Down Expand Up @@ -311,7 +371,8 @@ protected Map<Integer, ColumnValueConverter> getInputConverters(int numColumns,
for (int j = 0; j < numColumns; ++j) {
// For UDAF, we need to offset by 1 since first arg is state
int argIndex = isUdaf ? j + 1 : j;
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[argIndex]);
ColumnValueConverter converter = getInputConverter(objCache.argTypes[j].getPrimitiveType(),
objCache.argClass[argIndex]);
if (converter != null) {
converters.put(j, converter);
}
Expand All @@ -320,6 +381,6 @@ protected Map<Integer, ColumnValueConverter> getInputConverters(int numColumns,
}

protected ColumnValueConverter getOutputConverter() {
return getOutputConverter(retType, retClass);
return getOutputConverter(objCache.retType, objCache.retClass);
}
}
Loading
Loading