Skip to content

Commit

Permalink
🐛 Refactor plugin class loading to use custom PluginsClassloader. Thi…
Browse files Browse the repository at this point in the history
…s change ensures that plugin classes are consistently loaded by a single classloader, preventing conflicts arising from classes being loaded by multiple classloaders.
  • Loading branch information
ujibang committed Feb 13, 2024
1 parent d3c4e51 commit ff3c01c
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 57 deletions.
79 changes: 79 additions & 0 deletions core/src/main/java/org/restheart/plugins/PluginsClassloader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*-
* ========================LICENSE_START=================================
* restheart-core
* %%
* Copyright (C) 2014 - 2024 SoftInstigate
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* =========================LICENSE_END==================================
*/
package org.restheart.plugins;

import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;

/**
* Loads a class, including searching within all plugin JAR files.
* <p>
* This method is essential for the {@code collectFieldInjections()} method, which processes field injections
* annotated with {@code @Inject}. Specifically, it addresses cases where the {@code @Inject} annotation
* references a {@link org.restheart.plugins.Provider} that returns an object. The class of this object may reside
* in a plugin JAR file, necessitating a comprehensive search to locate and load the class correctly.
* </p>
*/
public class PluginsClassloader extends ClassLoader {
private static PluginsClassloader SINGLETON = null;

/**
* call after PluginsScanner.jars array is populated
* @param jars
*/
public static void init(URL[] jars) {
if (SINGLETON != null) {
throw new IllegalStateException("already initialized");
} else {
try {
SINGLETON = new PluginsClassloader(jars);
} catch(IOException ioe) {
throw new RuntimeException("error initializing", ioe);
}
}
}

private final URLClassLoader pluginsClassLoader;

private PluginsClassloader(URL[] jars) throws IOException {
this.pluginsClassLoader = new URLClassLoader(jars);
}

public static PluginsClassloader getInstance() {
if (SINGLETON == null) {
throw new IllegalStateException("not initialized");
} else {
return SINGLETON;
}
}

@Override
public Class<?> loadClass(String name) throws ClassNotFoundException {
try {
// use the current classloader
return this.getClass().getClassLoader().loadClass(name);
} catch (ClassNotFoundException cnfe) {
// look in the plugins jars
return this.pluginsClassLoader.loadClass(name);
}
}
}
22 changes: 1 addition & 21 deletions core/src/main/java/org/restheart/plugins/PluginsFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import static org.restheart.configuration.Utils.getOrDefault;

import java.lang.reflect.InvocationTargetException;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -56,15 +55,7 @@ public static PluginsFactory getInstance() {
return SINGLETON;
}

private final ArrayList<ClassLoader> classLoaders = new ArrayList<>();

private PluginsFactory() {
classLoaders.add(this.getClass().getClassLoader());

// take classloaders from PluginsScanner into account
if (PluginsScanner.jars != null) {
classLoaders.add(new URLClassLoader(PluginsScanner.jars));
}
}

private Set<PluginRecord<AuthMechanism>> authMechanismsCache = null;
Expand Down Expand Up @@ -283,18 +274,7 @@ private Class<Plugin> loadPluginClass(PluginDescriptor plugin) throws ClassNotFo
return PC_CACHE.get(plugin.clazz());
}

for (var classLoader : this.classLoaders) {
try {
var pluginc = (Class<Plugin>) classLoader.loadClass(plugin.clazz());

PC_CACHE.put(plugin.clazz(), pluginc);
return pluginc;
} catch (ClassNotFoundException cnfe) {
// nothing to do
}
}

throw new ClassNotFoundException("plugin class not found " + plugin.clazz());
return (Class<Plugin>) PluginsClassloader.getInstance().loadClass(plugin.clazz());
}

private Plugin instantiatePlugin(Class<Plugin> pc, String pluginType, String pluginName, Configuration conf) throws InstantiationException, IllegalAccessException, InvocationTargetException, IllegalArgumentException, SecurityException, ClassNotFoundException {
Expand Down
38 changes: 3 additions & 35 deletions core/src/main/java/org/restheart/plugins/PluginsScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import java.util.stream.Collectors;
import java.util.AbstractMap;

import org.checkerframework.common.returnsreceiver.qual.This;
import org.restheart.Bootstrapper;
import org.restheart.graal.NativeImageBuildTimeChecker;
import org.restheart.plugins.security.AuthMechanism;
Expand Down Expand Up @@ -300,7 +299,7 @@ private static ArrayList<InjectionDescriptor> collectFieldInjections(ClassInfo p
}

try {
var fieldClass = loadClass(fi.getTypeDescriptor().toString());
var fieldClass = PluginsClassloader.getInstance().loadClass(fi.getTypeDescriptor().toString());
ret.add(new FieldInjectionDescriptor(fi.getName(), fieldClass, annotationParams, fi.hashCode()));
} catch(ClassNotFoundException cnfe) {
// should not happen
Expand All @@ -312,39 +311,6 @@ private static ArrayList<InjectionDescriptor> collectFieldInjections(ClassInfo p
return ret;
}

private static ArrayList<ClassLoader> _classLoaders = null;

/**
* Loads a class, including searching within all plugin JAR files.
* <p>
* This method is essential for the {@code collectFieldInjections()} method, which processes field injections
* annotated with {@code @Inject}. Specifically, it addresses cases where the {@code @Inject} annotation
* references a {@link org.restheart.plugins.Provider} that returns an object. The class of this object may reside
* in a plugin JAR file, necessitating a comprehensive search to locate and load the class correctly.
* </p>
*/
private static Class<?> loadClass(String clazz) throws ClassNotFoundException {
if (_classLoaders == null || _classLoaders.isEmpty()) {
_classLoaders = new ArrayList<>();
_classLoaders.add(PluginsScanner.class.getClassLoader());

// take all classloaders into account to search also within all plugin JAR files
if (PluginsScanner.jars != null) {
_classLoaders.add(new URLClassLoader(PluginsScanner.jars));
}
}

for (var classLoader: _classLoaders) {
try {
return Class.forName(clazz, false, classLoader);
} catch (ClassNotFoundException cnfe) {
// nothing to do
}
}

throw new ClassNotFoundException("plugin class not found " + clazz);
}

/**
* this removes the reference to scanResult in the annotation info
* otherwise the huge object won't be garbage collected
Expand Down Expand Up @@ -373,6 +339,8 @@ public RuntimeClassGraph() {
this.jars = findPluginsJars(pdir);

if (jars != null && jars.length != 0) {
PluginsClassloader.init(jars);

this.classGraph = new ClassGraph().disableModuleScanning().disableDirScanning()
.disableNestedJarScanning().disableRuntimeInvisibleAnnotations()
.addClassLoader(new URLClassLoader(jars)).addClassLoader(ClassLoader.getSystemClassLoader())
Expand Down
5 changes: 4 additions & 1 deletion core/src/test/java/org/restheart/plugins/PluginsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

package org.restheart.plugins;

import java.net.URL;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mockStatic;
Expand Down Expand Up @@ -121,6 +123,7 @@ public void validProviders() {
}

private static MockedStatic<PluginsScanner> mockPluginsScanner(List<PluginDescriptor> providerDescriptors) {
PluginsClassloader.init(new URL[0]);
var scanner = mockStatic(PluginsScanner.class);
scanner.when(PluginsScanner::providers).thenReturn(providerDescriptors);
return scanner;
Expand Down Expand Up @@ -204,7 +207,7 @@ private static List<PluginDescriptor> providerDescriptors() {

var iC2 = new ArrayList<InjectionDescriptor>();
var apC2 = new ArrayList<AbstractMap.SimpleEntry<String, Object>>();
apC2.add(new AbstractMap.SimpleEntry<String, Object>("value", "c1"));
apC2.add(new AbstractMap.SimpleEntry<>("value", "c1"));
iC2.add(new FieldInjectionDescriptor("s", String.class, apC2, 8));

var providerDescriptors = new ArrayList<PluginDescriptor>();
Expand Down

0 comments on commit ff3c01c

Please sign in to comment.