Skip to content

Commit

Permalink
branch-2.1: [fix](jdbc catalog) Change BE jdbc Driver loading to Java…
Browse files Browse the repository at this point in the history
… code (apache#48002)

cherry-pick from (apache#46912)
  • Loading branch information
zy-kkk authored Feb 21, 2025
1 parent 0fe6c17 commit 469bc77
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 38 deletions.
20 changes: 5 additions & 15 deletions be/src/runtime/user_function_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,10 @@ Status UserFunctionCache::_download_lib(const std::string& url,
return Status::InternalError("fail to open file");
}

std::string real_url = _get_real_url(url);

Md5Digest digest;
HttpClient client;
int64_t file_size = 0;
RETURN_IF_ERROR(client.init(real_url));
RETURN_IF_ERROR(client.init(url));
Status status;
auto download_cb = [&status, &tmp_file, &fp, &digest, &file_size](const void* data,
size_t length) {
Expand All @@ -297,11 +295,10 @@ Status UserFunctionCache::_download_lib(const std::string& url,
digest.digest();
if (!iequal(digest.hex(), entry->checksum)) {
fmt::memory_buffer error_msg;
fmt::format_to(
error_msg,
" The checksum is not equal of {} ({}). The init info of first create entry is:"
"{} But download file check_sum is: {}, file_size is: {}.",
url, real_url, entry->debug_string(), digest.hex(), file_size);
fmt::format_to(error_msg,
" The checksum is not equal of {}. The init info of first create entry is:"
"{} But download file check_sum is: {}, file_size is: {}.",
url, entry->debug_string(), digest.hex(), file_size);
std::string error(fmt::to_string(error_msg));
LOG(WARNING) << error;
return Status::InternalError(error);
Expand All @@ -323,13 +320,6 @@ Status UserFunctionCache::_download_lib(const std::string& url,
return Status::OK();
}

std::string UserFunctionCache::_get_real_url(const std::string& url) {
if (url.find(":/") == std::string::npos) {
return "file://" + config::jdbc_drivers_dir + "/" + url;
}
return url;
}

std::string UserFunctionCache::_get_file_name_from_url(const std::string& url) const {
std::string file_name;
size_t last_slash_pos = url.find_last_of('/');
Expand Down
1 change: 0 additions & 1 deletion be/src/runtime/user_function_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class UserFunctionCache {
const std::string& file_name);
void _destroy_cache_entry(std::shared_ptr<UserFunctionCacheEntry> entry);

std::string _get_real_url(const std::string& url);
std::string _get_file_name_from_url(const std::string& url) const;
std::vector<std::string> _split_string_by_checksum(const std::string& file);

Expand Down
29 changes: 11 additions & 18 deletions be/src/vec/exec/vjdbc_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,7 @@ Status JdbcConnector::open(RuntimeState* state, bool read) {
// Add a scoped cleanup jni reference object. This cleans up local refs made below.
JniLocalFrame jni_frame;
{
std::string local_location;
std::hash<std::string> hash_str;
auto* function_cache = UserFunctionCache::instance();
if (_conn_param.resource_name.empty()) {
// for jdbcExternalTable, _conn_param.resource_name == ""
// so, we use _conn_param.driver_path as key of jarpath
SCOPED_RAW_TIMER(&_jdbc_statistic._load_jar_timer);
RETURN_IF_ERROR(function_cache->get_jarpath(
std::abs((int64_t)hash_str(_conn_param.driver_path)), _conn_param.driver_path,
_conn_param.driver_checksum, &local_location));
} else {
SCOPED_RAW_TIMER(&_jdbc_statistic._load_jar_timer);
RETURN_IF_ERROR(function_cache->get_jarpath(
std::abs((int64_t)hash_str(_conn_param.resource_name)), _conn_param.driver_path,
_conn_param.driver_checksum, &local_location));
}
VLOG_QUERY << "driver local path = " << local_location;
std::string driver_path = _get_real_url(_conn_param.driver_path);

TJdbcExecutorCtorParams ctor_params;
ctor_params.__set_statement(_sql_str);
Expand All @@ -144,7 +128,8 @@ Status JdbcConnector::open(RuntimeState* state, bool read) {
ctor_params.__set_jdbc_user(_conn_param.user);
ctor_params.__set_jdbc_password(_conn_param.passwd);
ctor_params.__set_jdbc_driver_class(_conn_param.driver_class);
ctor_params.__set_driver_path(local_location);
ctor_params.__set_driver_path(driver_path);
ctor_params.__set_jdbc_driver_checksum(_conn_param.driver_checksum);
if (state == nullptr) {
ctor_params.__set_batch_size(read ? 1 : 0);
} else {
Expand Down Expand Up @@ -601,4 +586,12 @@ jobject JdbcConnector::_get_java_table_type(JNIEnv* env, TOdbcTableType::type ta
env->CallStaticObjectMethod(enumClass, findByValueMethod, static_cast<jint>(tableType));
return javaEnumObj;
}

std::string JdbcConnector::_get_real_url(const std::string& url) {
if (url.find(":/") == std::string::npos) {
return "file://" + config::jdbc_drivers_dir + "/" + url;
}
return url;
}

} // namespace doris::vectorized
2 changes: 2 additions & 0 deletions be/src/vec/exec/vjdbc_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class JdbcConnector : public TableConnector {
int rows);
jobject _get_java_table_type(JNIEnv* env, TOdbcTableType::type tableType);

std::string _get_real_url(const std::string& url);

bool _closed = false;
jclass _executor_factory_clazz;
jclass _executor_clazz;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.jdbc;

import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnType;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorColumn;
Expand All @@ -27,16 +26,25 @@
import org.apache.doris.thrift.TJdbcOperation;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.zaxxer.hikari.HikariDataSource;
import org.apache.commons.codec.binary.Hex;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.semver4j.Semver;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.Date;
Expand All @@ -57,6 +65,7 @@ public abstract class BaseJdbcExecutor implements JdbcExecutor {
private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();
private HikariDataSource hikariDataSource = null;
private final byte[] hikariDataSourceLock = new byte[0];
private ClassLoader classLoader = null;
private Connection conn = null;
protected JdbcDataSourceConfig config;
protected PreparedStatement preparedStatement = null;
Expand All @@ -68,6 +77,7 @@ public abstract class BaseJdbcExecutor implements JdbcExecutor {
protected int batchSizeNum = 0;
protected int curBlockRows = 0;
protected String jdbcDriverVersion;
private static final Map<URL, ClassLoader> classLoaderMap = Maps.newConcurrentMap();

public BaseJdbcExecutor(byte[] thriftParams) throws Exception {
setJdbcDriverSystemProperties();
Expand All @@ -85,6 +95,7 @@ public BaseJdbcExecutor(byte[] thriftParams) throws Exception {
.setJdbcUrl(request.jdbc_url)
.setJdbcDriverUrl(request.driver_path)
.setJdbcDriverClass(request.jdbc_driver_class)
.setJdbcDriverChecksum(request.jdbc_driver_checksum)
.setBatchSize(request.batch_size)
.setOp(request.op)
.setTableType(request.table_type)
Expand Down Expand Up @@ -298,8 +309,7 @@ private void init(JdbcDataSourceConfig config, String sql) throws JdbcExecutorEx
ClassLoader oldClassLoader = Thread.currentThread().getContextClassLoader();
String hikariDataSourceKey = config.createCacheKey();
try {
ClassLoader parent = getClass().getClassLoader();
ClassLoader classLoader = UdfUtils.getClassLoader(config.getJdbcDriverUrl(), parent);
initializeClassLoader(config);
Thread.currentThread().setContextClassLoader(classLoader);
hikariDataSource = JdbcDataSource.getDataSource().getSource(hikariDataSourceKey);
if (hikariDataSource == null) {
Expand Down Expand Up @@ -357,6 +367,60 @@ private void init(JdbcDataSourceConfig config, String sql) throws JdbcExecutorEx
}
}

private synchronized void initializeClassLoader(JdbcDataSourceConfig config)
throws MalformedURLException, FileNotFoundException {
try {
URL[] urls = {new URL(config.getJdbcDriverUrl())};
if (classLoaderMap.containsKey(urls[0])) {
this.classLoader = classLoaderMap.get(urls[0]);
} else {
String expectedChecksum = config.getJdbcDriverChecksum();
String actualChecksum = computeObjectChecksum(urls[0].toString(), null);
if (!expectedChecksum.equals(actualChecksum)) {
throw new RuntimeException("Checksum mismatch for JDBC driver.");
}
ClassLoader parent = getClass().getClassLoader();
this.classLoader = URLClassLoader.newInstance(urls, parent);
classLoaderMap.put(urls[0], this.classLoader);
}
} catch (MalformedURLException e) {
throw new RuntimeException("Error loading JDBC driver.", e);
}
}

public static String computeObjectChecksum(String urlStr, String encodedAuthInfo) {
try (InputStream inputStream = getInputStreamFromUrl(urlStr, encodedAuthInfo, 10000, 10000)) {
MessageDigest digest = MessageDigest.getInstance("MD5");
byte[] buf = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buf)) != -1) {
digest.update(buf, 0, bytesRead);
}
return Hex.encodeHexString(digest.digest());
} catch (IOException | NoSuchAlgorithmException e) {
throw new RuntimeException("Compute driver checksum from url: " + urlStr
+ " encountered an error: " + e.getMessage());
}
}

public static InputStream getInputStreamFromUrl(String urlStr, String encodedAuthInfo, int connectTimeoutMs,
int readTimeoutMs) throws IOException {
try {
URL url = new URL(urlStr);
URLConnection conn = url.openConnection();

if (encodedAuthInfo != null) {
conn.setRequestProperty("Authorization", "Basic " + encodedAuthInfo);
}

conn.setConnectTimeout(connectTimeoutMs);
conn.setReadTimeout(readTimeoutMs);
return conn.getInputStream();
} catch (Exception e) {
throw new IOException("Failed to open URL connection: " + urlStr, e);
}
}

protected void setValidationQuery(HikariDataSource ds) {
ds.setConnectionTestQuery("SELECT 1");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class JdbcDataSourceConfig {
private String jdbcPassword;
private String jdbcDriverUrl;
private String jdbcDriverClass;
private String jdbcDriverChecksum;
private int batchSize;
private TJdbcOperation op;
private TOdbcTableType tableType;
Expand Down Expand Up @@ -96,6 +97,15 @@ public JdbcDataSourceConfig setJdbcDriverClass(String jdbcDriverClass) {
return this;
}

public String getJdbcDriverChecksum() {
return jdbcDriverChecksum;
}

public JdbcDataSourceConfig setJdbcDriverChecksum(String jdbcDriverChecksum) {
this.jdbcDriverChecksum = jdbcDriverChecksum;
return this;
}

public int getBatchSize() {
return batchSize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ public static String computeObjectChecksum(String driverPath) throws DdlExceptio
}

public static String getFullDriverUrl(String driverUrl) throws IllegalArgumentException {
if (!(driverUrl.startsWith("file://") || driverUrl.startsWith("http://")
|| driverUrl.startsWith("https://") || driverUrl.matches("^[^:/]+\\.jar$"))) {
throw new IllegalArgumentException("Invalid driver URL format. Supported formats are: "
+ "file://xxx.jar, http://xxx.jar, https://xxx.jar, or xxx.jar (without prefix).");
}

try {
URI uri = new URI(driverUrl);
String schema = uri.getScheme();
Expand Down Expand Up @@ -481,4 +487,3 @@ public static void checkConnectionPoolProperties(int minSize, int maxSize, int m
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;

import java.util.Map;

Expand Down Expand Up @@ -216,4 +217,54 @@ public void testJdbcDriverPtah() {
});
Assert.assertEquals("Driver URL does not match any allowed paths: file:///postgresql-42.5.0.jar", exception.getMessage());
}

@Test
public void testValidDriverUrls() {
String fileUrl = "file://path/to/driver.jar";
Assertions.assertDoesNotThrow(() -> {
String result = JdbcResource.getFullDriverUrl(fileUrl);
Assert.assertEquals(fileUrl, result);
});

String httpUrl = "http://example.com/driver.jar";
Assertions.assertDoesNotThrow(() -> {
String result = JdbcResource.getFullDriverUrl(httpUrl);
Assert.assertEquals(httpUrl, result);
});

String httpsUrl = "https://example.com/driver.jar";
Assertions.assertDoesNotThrow(() -> {
String result = JdbcResource.getFullDriverUrl(httpsUrl);
Assert.assertEquals(httpsUrl, result);
});

String jarFile = "driver.jar";
Assertions.assertDoesNotThrow(() -> {
String result = JdbcResource.getFullDriverUrl(jarFile);
Assert.assertTrue(result.startsWith("file://"));
});
}

@Test
public void testInvalidDriverUrls() {
String invalidUrl1 = "/mnt/path/to/driver.jar";
Assert.assertThrows(IllegalArgumentException.class, () -> {
JdbcResource.getFullDriverUrl(invalidUrl1);
});

String invalidUrl2 = "ftp://example.com/driver.jar";
Assert.assertThrows(IllegalArgumentException.class, () -> {
JdbcResource.getFullDriverUrl(invalidUrl2);
});

String invalidUrl3 = "";
Assert.assertThrows(IllegalArgumentException.class, () -> {
JdbcResource.getFullDriverUrl(invalidUrl3);
});

String invalidUrl4 = "example.com/driver";
Assert.assertThrows(IllegalArgumentException.class, () -> {
JdbcResource.getFullDriverUrl(invalidUrl4);
});
}
}
1 change: 1 addition & 0 deletions gensrc/thrift/Types.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ struct TJdbcExecutorCtorParams {
14: optional i32 connection_pool_cache_clear_time
15: optional bool connection_pool_keep_alive
16: optional i64 catalog_id
17: optional string jdbc_driver_checksum
}

struct TJavaUdfExecutorCtorParams {
Expand Down

0 comments on commit 469bc77

Please sign in to comment.