Skip to content

Commit

Permalink
refactor: optimize dependency graph performances (#4587)
Browse files Browse the repository at this point in the history
  • Loading branch information
ndr-brt authored Oct 28, 2024
1 parent 55405aa commit e6b6560
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -44,6 +43,10 @@
import java.util.stream.Stream;

import static java.util.Optional.ofNullable;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;


/**
Expand Down Expand Up @@ -77,57 +80,55 @@ public List<InjectionContainer<ServiceExtension>> of(List<ServiceExtension> load

// check if all injected fields are satisfied, collect missing ones and throw exception otherwise
var unsatisfiedInjectionPoints = new ArrayList<InjectionPoint<ServiceExtension>>();
var unsatisfiedRequirements = new ArrayList<String>();

var injectionPoints = extensions.stream()
.flatMap(ext -> getInjectedFields(ext).stream()
.peek(injectionPoint -> {
if (!canResolve(dependencyMap, injectionPoint.getType())) {
if (injectionPoint.isRequired()) {
unsatisfiedInjectionPoints.add(injectionPoint);
.collect(toMap(identity(), ext -> {

//check that all the @Required features are there
getRequiredFeatures(ext.getClass()).forEach(feature -> {
var dependencies = dependencyMap.get(feature);
if (dependencies == null) {
unsatisfiedRequirements.add(feature.getName());
} else {
dependencies.forEach(dependency -> sort.addDependency(ext, dependency));
}
});

return injectionPointScanner.getInjectionPoints(ext)
.peek(injectionPoint -> {
if (!canResolve(dependencyMap, injectionPoint.getType())) {
if (injectionPoint.isRequired()) {
unsatisfiedInjectionPoints.add(injectionPoint);
}
} else {
// get() would return null, if the feature is already in the context's service list
ofNullable(dependencyMap.get(injectionPoint.getType()))
.ifPresent(l -> l.stream()
.filter(d -> !Objects.equals(d, ext)) // remove dependencies onto oneself
.forEach(provider -> sort.addDependency(ext, provider)));
}
} else {
// get() would return null, if the feature is already in the context's service list
ofNullable(dependencyMap.get(injectionPoint.getType()))
.ifPresent(l -> l.stream()
.filter(d -> !Objects.equals(d, ext)) // remove dependencies onto oneself
.forEach(provider -> sort.addDependency(ext, provider)));
}
})
)
.collect(Collectors.toList());

//throw an exception if still unsatisfied links
})
.collect(toSet());
}));

if (!unsatisfiedInjectionPoints.isEmpty()) {
var string = "The following injected fields were not provided:\n";
string += unsatisfiedInjectionPoints.stream().map(InjectionPoint::toString).collect(Collectors.joining("\n"));
throw new EdcInjectionException(string);
}

//check that all the @Required features are there
var unsatisfiedRequirements = new ArrayList<String>();
extensions.forEach(ext -> {
var features = getRequiredFeatures(ext.getClass());
features.forEach(feature -> {
var dependencies = dependencyMap.get(feature);
if (dependencies == null) {
unsatisfiedRequirements.add(feature.getName());
} else {
dependencies.forEach(dependency -> sort.addDependency(ext, dependency));
}
});
});

if (!unsatisfiedRequirements.isEmpty()) {
var string = String.format("The following @Require'd features were not provided: [%s]", String.join(", ", unsatisfiedRequirements));
throw new EdcException(string);
}

sort.sort(extensions);

// todo: should the list of InjectionContainers be generated directly by the flatmap?
// convert the sorted list of extensions into an equally sorted list of InjectionContainers
return extensions.stream()
.map(se -> new InjectionContainer<>(se, injectionPoints.stream().filter(ip -> ip.getInstance() == se).collect(Collectors.toSet())))
.collect(Collectors.toList());
.map(key -> new InjectionContainer<>(key, injectionPoints.get(key)))
.toList();
}

private boolean canResolve(Map<Class<?>, List<ServiceExtension>> dependencyMap, Class<?> featureName) {
Expand All @@ -142,18 +143,17 @@ private boolean canResolve(Map<Class<?>, List<ServiceExtension>> dependencyMap,

private Map<Class<?>, List<ServiceExtension>> createDependencyMap(List<ServiceExtension> extensions) {
Map<Class<?>, List<ServiceExtension>> dependencyMap = new HashMap<>();
extensions.forEach(ext -> getDefaultProvidedFeatures(ext).forEach(feature -> dependencyMap.computeIfAbsent(feature, k -> new ArrayList<>()).add(ext)));
extensions.forEach(ext -> getProvidedFeatures(ext).forEach(feature -> dependencyMap.computeIfAbsent(feature, k -> new ArrayList<>()).add(ext)));
return dependencyMap;
}

private Set<Class<?>> getRequiredFeatures(Class<?> clazz) {
private Stream<Class<?>> getRequiredFeatures(Class<?> clazz) {
var requiresAnnotation = clazz.getAnnotation(Requires.class);
if (requiresAnnotation != null) {
var features = requiresAnnotation.value();
return Stream.of(features).collect(Collectors.toSet());
return Stream.of(features);
}
return Collections.emptySet();
return Stream.empty();
}

/**
Expand All @@ -165,42 +165,23 @@ private Set<Class<?>> getProvidedFeatures(ServiceExtension ext) {
// check all @Provides
var providesAnnotation = ext.getClass().getAnnotation(Provides.class);
if (providesAnnotation != null) {
var featureStrings = Arrays.stream(providesAnnotation.value()).collect(Collectors.toSet());
allProvides.addAll(featureStrings);
allProvides.addAll(Arrays.asList(providesAnnotation.value()));
}

// check all @Provider methods
allProvides.addAll(new ProviderMethodScanner(ext).nonDefaultProviders().stream().map(ProviderMethod::getReturnType).collect(Collectors.toSet()));
new ProviderMethodScanner(ext).allProviders().map(ProviderMethod::getReturnType).forEach(allProvides::add);
return allProvides;
}

private Set<Class<?>> getDefaultProvidedFeatures(ServiceExtension ext) {
return new ProviderMethodScanner(ext).defaultProviders().stream()
.map(ProviderMethod::getReturnType)
.collect(Collectors.toSet());
}

/**
* Handles core-, transfer- and contract-extensions and inserts them at the beginning of the list so that
* explicit @Requires annotations are not necessary
*/
private List<ServiceExtension> sortByType(List<ServiceExtension> loadedExtensions) {
var baseDependencies = loadedExtensions.stream().filter(e -> e.getClass().getAnnotation(BaseExtension.class) != null).collect(Collectors.toList());
if (baseDependencies.isEmpty()) {
throw new EdcException("No base dependencies were found on the classpath. Please add the \"core:common:connector-core\" module to your classpath!");
}

return loadedExtensions.stream().sorted(new ServiceExtensionComparator()).collect(Collectors.toList());
}

/**
* Obtains all features a specific extension provides as strings
*/
private Set<InjectionPoint<ServiceExtension>> getInjectedFields(ServiceExtension ext) {
// initialize with legacy list
return injectionPointScanner.getInjectionPoints(ext);
return loadedExtensions.stream().sorted(new SortByType()).collect(toList());
}

private static class ServiceExtensionComparator implements Comparator<ServiceExtension> {
private static class SortByType implements Comparator<ServiceExtension> {
@Override
public int compare(ServiceExtension o1, ServiceExtension o2) {
return orderFor(o1.getClass()).compareTo(orderFor(o2.getClass()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Scans a particular (partly constructed) object for fields that are annotated with {@link Inject} and returns them
* in a {@link Set}
*/
public class InjectionPointScanner {
public <T> Set<InjectionPoint<T>> getInjectionPoints(T instance) {

public <T> Stream<InjectionPoint<T>> getInjectionPoints(T instance) {

var targetClass = instance.getClass();

Expand All @@ -34,7 +35,6 @@ public <T> Set<InjectionPoint<T>> getInjectionPoints(T instance) {
.map(f -> {
var isRequired = f.getAnnotation(Inject.class).required();
return new FieldInjectionPoint<>(instance, f, isRequired);
})
.collect(Collectors.toSet());
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import org.eclipse.edc.spi.system.ServiceExtension;

import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.lang.reflect.Modifier.isPublic;

Expand All @@ -33,34 +32,38 @@ public ProviderMethodScanner(ServiceExtension target) {
this.target = target;
}

/**
* Returns all methods annotated with {@link Provider}.
*/
public Stream<ProviderMethod> allProviders() {
return getProviderMethods(target);
}

/**
* Returns all methods annotated with {@link Provider}, where {@link Provider#isDefault()} is {@code false}
*/
public Set<ProviderMethod> nonDefaultProviders() {
return getProviderMethods(target).stream().filter(pm -> !pm.isDefault()).collect(Collectors.toSet());
public Stream<ProviderMethod> nonDefaultProviders() {
return getProviderMethods(target).filter(pm -> !pm.isDefault());
}

/**
* Returns all methods annotated with {@link Provider}, where {@link Provider#isDefault()} is {@code true}
*/
public Set<ProviderMethod> defaultProviders() {
return getProviderMethods(target).stream().filter(ProviderMethod::isDefault).collect(Collectors.toSet());
public Stream<ProviderMethod> defaultProviders() {
return getProviderMethods(target).filter(ProviderMethod::isDefault);
}

private Set<ProviderMethod> getProviderMethods(Object extension) {
var methods = Arrays.stream(extension.getClass().getDeclaredMethods())
private Stream<ProviderMethod> getProviderMethods(Object extension) {
return Arrays.stream(extension.getClass().getDeclaredMethods())
.filter(m -> m.getAnnotation(Provider.class) != null)
.map(ProviderMethod::new)
.collect(Collectors.toSet());

if (methods.stream().anyMatch(m -> m.getReturnType().equals(Void.TYPE))) {
throw new EdcInjectionException("Methods annotated with @Provider must have a non-void return type!");
}
if (methods.stream().anyMatch(m -> !isPublic(m.getMethod().getModifiers()))) {
throw new EdcInjectionException("Methods annotated with @Provider must be public!");
}
return methods;
.peek(method -> {
if (method.getReturnType().equals(Void.TYPE)) {
throw new EdcInjectionException("Methods annotated with @Provider must have a non-void return type!");
}
if (!isPublic(method.getMethod().getModifiers())) {
throw new EdcInjectionException("Methods annotated with @Provider must be public!");
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@ void setup() {
scanner = new ProviderMethodScanner(new TestExtension());
}

@Test
void allProviders() {
assertThat(scanner.allProviders()).hasSize(3);
}

@Test
void providerMethods() {
assertThat(scanner
.nonDefaultProviders())
.hasSize(2);

}

@Test
Expand All @@ -52,16 +56,16 @@ void defaultProviderMethods() throws NoSuchMethodException {
@Test
void verifyInvalidReturnType() {
var scanner = new ProviderMethodScanner(new InvalidTestExtension());
assertThatThrownBy(scanner::nonDefaultProviders).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(scanner::defaultProviders).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.nonDefaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.defaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
}

@Test
void verifyInvalidVisibility() {
var scanner = new ProviderMethodScanner(new InvalidTestExtension2());

assertThatThrownBy(scanner::nonDefaultProviders).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(scanner::defaultProviders).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.nonDefaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
assertThatThrownBy(() -> scanner.defaultProviders().toList()).isInstanceOf(EdcInjectionException.class);
}

private static class TestExtension implements ServiceExtension {
Expand Down Expand Up @@ -99,4 +103,4 @@ Object invalidProvider() {
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.eclipse.edc.spi.system.ServiceExtension;
import org.junit.jupiter.api.Test;

import java.util.Set;
import java.util.stream.Stream;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -32,9 +32,10 @@

class RegistrationPhaseTest extends PhaseTest {

private final ProviderMethodScanner scannerMock = mock();

@Test
void registerProviders_noProviderMethod() {
ProviderMethodScanner scannerMock = mock(ProviderMethodScanner.class);
var rp = new RegistrationPhase(new Phase(injector, container, context, monitor) {
}, scannerMock);
when(container.getInjectionTarget()).thenReturn(mock(ServiceExtension.class));
Expand All @@ -53,7 +54,7 @@ void registerProviders_withProvider_noDefault_notRegistered() {
when(providerMethod.invoke(any(), any())).thenReturn(new TestService());
when(providerMethod.getReturnType()).thenAnswer(a -> TestService.class);
when(context.hasService(TestService.class)).thenReturn(false);
when(scannerMock.nonDefaultProviders()).thenReturn(Set.of(providerMethod));
when(scannerMock.nonDefaultProviders()).thenReturn(Stream.of(providerMethod));

var rp = new RegistrationPhase(new Phase(injector, container, context, monitor) {
}, scannerMock);
Expand All @@ -69,7 +70,7 @@ void registerProviders_withProvider_noDefault_notRegistered() {
@Test
void registerProviders_withProvider_isDefault_notRegistered() {
var scannerMock = mock(ProviderMethodScanner.class);
when(scannerMock.nonDefaultProviders()).thenReturn(Set.of());
when(scannerMock.nonDefaultProviders()).thenReturn(Stream.empty());

var rp = new RegistrationPhase(new Phase(injector, container, context, monitor) {
}, scannerMock);
Expand Down Expand Up @@ -105,4 +106,4 @@ void registerProviders_withProvider_isDefault_isRegistered() {
private static class TestService {

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.eclipse.edc.boot.system.injection.InjectionPointScanner;
import org.eclipse.edc.boot.system.injection.InjectorImpl;
import org.eclipse.edc.boot.system.injection.ObjectFactory;
import org.eclipse.edc.boot.system.injection.ReflectiveObjectFactory;
import org.eclipse.edc.boot.system.runtime.BaseRuntime;
import org.eclipse.edc.spi.system.ServiceExtension;
import org.eclipse.edc.spi.system.ServiceExtensionContext;
Expand Down
Loading

0 comments on commit e6b6560

Please sign in to comment.