diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt index 1e89ed044..10b8f7d77 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt @@ -31,7 +31,6 @@ import org.dataloader.DataLoaderRegistry import org.dataloader.MappedBatchLoader import org.dataloader.MappedBatchLoaderWithContext import org.dataloader.registries.DispatchPredicate -import org.dataloader.registries.ScheduledDataLoaderRegistry import org.slf4j.Logger import org.slf4j.LoggerFactory import org.springframework.aop.support.AopUtils @@ -63,7 +62,7 @@ class DgsDataLoaderProvider( } fun buildRegistryWithContextSupplier(contextSupplier: Supplier): DataLoaderRegistry { - val registry = ScheduledDataLoaderRegistry.newScheduledRegistry().dispatchPredicate(DispatchPredicate.DISPATCH_NEVER).build() + val registry = DgsDataLoaderRegistry() val totalTime = measureTimeMillis { val extensionProviders = applicationContext .getBeanProvider(DataLoaderInstrumentationExtensionProvider::class.java) @@ -219,7 +218,7 @@ class DgsDataLoaderProvider( private fun registerDataLoader( holder: LoaderHolder<*>, - registry: ScheduledDataLoaderRegistry, + registry: DgsDataLoaderRegistry, contextSupplier: Supplier<*>, extensionProviders: Iterable ) { @@ -236,9 +235,9 @@ class DgsDataLoaderProvider( } if (holder.dispatchPredicate == null) { - registry.register(holder.name, loader, DispatchPredicate.DISPATCH_ALWAYS) + registry.register(holder.name, loader) } else { - registry.register(holder.name, loader, holder.dispatchPredicate) + registry.registerWithDispatchPredicate(holder.name, loader, holder.dispatchPredicate) } } diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderRegistry.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderRegistry.kt new file mode 100644 index 000000000..6bfda6511 --- /dev/null +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderRegistry.kt @@ -0,0 +1,198 @@ +/* + * Copyright 2023 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.graphql.dgs.internal + +import org.dataloader.DataLoader +import org.dataloader.DataLoaderRegistry +import org.dataloader.registries.DispatchPredicate +import org.dataloader.registries.ScheduledDataLoaderRegistry +import org.dataloader.stats.Statistics +import java.util.concurrent.ConcurrentHashMap +import java.util.function.Function + +/** + * The DgsDataLoaderRegistry is a registry of DataLoaderRegistry instances. It supports specifying + * DispatchPredicate on a per data loader basis, specified using @DispatchPredicate annotation. It creates an instance + * of a ScheduledDataLoaderRegistry for every data loader that is registered and delegates to the mapping instance of + * the registry based on the key. We need to create a registry per data loader since a DispatchPredicate is applicable + * for an instance of the ScheduledDataLoaderRegistry. + * https://github.com/graphql-java/java-dataloader#scheduled-dispatching + */ +open class DgsDataLoaderRegistry : DataLoaderRegistry() { + private val scheduledDataLoaderRegistries: MutableMap = ConcurrentHashMap() + private val dataLoaderRegistry = DataLoaderRegistry() + + /** + * This will register a new dataloader + * + * @param key the key to put the data loader under + * @param dataLoader the data loader to register + * + * @return this registry + */ + override fun register(key: String, dataLoader: DataLoader<*, *>): DataLoaderRegistry { + dataLoaderRegistry.register(key, dataLoader) + return this + } + + /** + * This will register a new dataloader with a dispatch predicate set up for that loader + * + * @param key the key to put the data loader under + * @param dataLoader the data loader to register + * + * @return this registry + */ + fun registerWithDispatchPredicate( + key: String, + dataLoader: DataLoader<*, *>, + dispatchPredicate: DispatchPredicate + ): DataLoaderRegistry { + val registry = ScheduledDataLoaderRegistry.newScheduledRegistry().register(key, dataLoader) + .dispatchPredicate(dispatchPredicate) + .build() + scheduledDataLoaderRegistries.putIfAbsent(key, registry) + return this + } + + /** + * Computes a data loader if absent or return it if it was + * already registered at that key. + * + * + * Note: The entire method invocation is performed atomically, + * so the function is applied at most once per key. + * + * @param key the key of the data loader + * @param mappingFunction the function to compute a data loader + * @param the type of keys + * @param the type of values + * + * @return a data loader + */ + override fun computeIfAbsent( + key: String, + mappingFunction: Function>? + ): DataLoader { + // we do not support this method for registering with dispatch predicates + return dataLoaderRegistry.computeIfAbsent(key, mappingFunction!!) as DataLoader + } + + /** + * This operation is not supported since we cannot store a dataloader registry without a key. + */ + override fun combine(registry: DataLoaderRegistry): DataLoaderRegistry? { + throw UnsupportedOperationException("Cannot combine a DgsDataLoaderRegistry with another registry") + } + + /** + * @return the currently registered data loaders + */ + override fun getDataLoaders(): List> { + return scheduledDataLoaderRegistries.flatMap { it.value.dataLoaders }.plus(dataLoaderRegistry.dataLoaders) + } + + /** + * @return the currently registered data loaders as a map + */ + override fun getDataLoadersMap(): Map> { + var dataLoadersMap: Map> = emptyMap() + scheduledDataLoaderRegistries.forEach { + dataLoadersMap = dataLoadersMap.plus(it.value.dataLoadersMap) + } + return LinkedHashMap(dataLoadersMap.plus(dataLoaderRegistry.dataLoadersMap)) + } + + /** + * This will unregister a new dataloader + * + * @param key the key of the data loader to unregister + * + * @return this registry + */ + override fun unregister(key: String): DataLoaderRegistry { + scheduledDataLoaderRegistries.remove(key) + dataLoaderRegistry.unregister(key) + return this + } + + /** + * Returns the dataloader that was registered under the specified key + * + * @param key the key of the data loader + * @param the type of keys + * @param the type of values + * + * @return a data loader or null if its not present + */ + override fun getDataLoader(key: String): DataLoader? { + return dataLoaderRegistry.getDataLoader(key) ?: scheduledDataLoaderRegistries[key]?.getDataLoader(key) + } + + override fun getKeys(): Set { + return scheduledDataLoaderRegistries.keys.plus(dataLoaderRegistry.keys) + } + + /** + * This will be called [org.dataloader.DataLoader.dispatch] on each of the registered + * [org.dataloader.DataLoader]s + */ + override fun dispatchAll() { + scheduledDataLoaderRegistries.forEach { + it.value.dispatchAll() + } + dataLoaderRegistry.dispatchAll() + } + + /** + * Similar to [DataLoaderRegistry.dispatchAll], this calls [org.dataloader.DataLoader.dispatch] on + * each of the registered [org.dataloader.DataLoader]s, but returns the number of dispatches. + * + * @return total number of entries that were dispatched from registered [org.dataloader.DataLoader]s. + */ + override fun dispatchAllWithCount(): Int { + var sum = 0 + scheduledDataLoaderRegistries.forEach { + sum += it.value.dispatchAllWithCount() + } + sum += dataLoaderRegistry.dispatchAllWithCount() + return sum + } + + /** + * @return The sum of all batched key loads that need to be dispatched from all registered + * [org.dataloader.DataLoader]s + */ + override fun dispatchDepth(): Int { + var totalDispatchDepth = 0 + scheduledDataLoaderRegistries.forEach { + totalDispatchDepth += it.value.dispatchDepth() + } + totalDispatchDepth += dataLoaderRegistry.dispatchDepth() + + return totalDispatchDepth + } + + override fun getStatistics(): Statistics { + var stats = Statistics() + scheduledDataLoaderRegistries.forEach { + stats = stats.combine(it.value.statistics) + } + stats = stats.combine(dataLoaderRegistry.statistics) + return stats + } +} diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderRegistryTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderRegistryTest.kt new file mode 100644 index 000000000..361e55cf6 --- /dev/null +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderRegistryTest.kt @@ -0,0 +1,215 @@ +/* + * Copyright 2023 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.graphql.dgs + +import com.netflix.graphql.dgs.internal.DgsDataLoaderRegistry +import io.mockk.every +import io.mockk.impl.annotations.MockK +import io.mockk.mockk +import org.assertj.core.api.Assertions.assertThat +import org.dataloader.BatchLoader +import org.dataloader.DataLoader +import org.dataloader.DataLoaderFactory +import org.dataloader.DataLoaderRegistry +import org.dataloader.registries.DispatchPredicate +import org.dataloader.stats.Statistics +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionStage + +class DgsDataLoaderRegistryTest { + + private val dgsDataLoaderRegistry = DgsDataLoaderRegistry() + private val dataLoaderA = ExampleDataLoaderA() + private val dataLoaderB = ExampleDataLoaderB() + + @MockK + var mockDataLoaderA: DataLoader = mockk() + + @MockK + var mockDataLoaderB: DataLoader = mockk() + + @Test + fun register() { + val newLoader = DataLoaderFactory.newDataLoader(dataLoaderA) + dgsDataLoaderRegistry.register("exampleLoaderA", newLoader) + assertThat(dgsDataLoaderRegistry.dataLoaders.size).isEqualTo(1) + val registeredLoader = dgsDataLoaderRegistry.getDataLoader("exampleLoaderA") + assertThat(registeredLoader).isNotNull + } + + @Test + fun registerWithScheduledDispatch() { + val newLoader = DataLoaderFactory.newDataLoader(dataLoaderB) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + newLoader, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + assertThat(dgsDataLoaderRegistry.dataLoaders.size).isEqualTo(1) + val registeredLoader = dgsDataLoaderRegistry.getDataLoader("exampleLoaderB") + assertThat(registeredLoader).isNotNull + } + + @Test + fun unregister() { + DataLoaderFactory.newDataLoader(dataLoaderA) + dgsDataLoaderRegistry.register("exampleLoaderA", DataLoaderFactory.newDataLoader(dataLoaderA)) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + DataLoaderFactory.newDataLoader(dataLoaderB), + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + assertThat(dgsDataLoaderRegistry.dataLoaders.size).isEqualTo(2) + dgsDataLoaderRegistry.unregister("exampleLoaderA") + assertThat(dgsDataLoaderRegistry.dataLoaders.size).isEqualTo(1) + dgsDataLoaderRegistry.unregister("exampleLoaderB") + assertThat(dgsDataLoaderRegistry.dataLoaders.size).isEqualTo(0) + } + + @Test + fun combine() { + val error: UnsupportedOperationException = assertThrows { + dgsDataLoaderRegistry.combine(DataLoaderRegistry()) + } + } + + @Test + fun computeIfAbsent() { + val dataLoader = DataLoaderFactory.newDataLoader(dataLoaderA) as DataLoader<*, *> + dgsDataLoaderRegistry.computeIfAbsent("exampleLoader") { dataLoader } + + val loader = dgsDataLoaderRegistry.getDataLoader("exampleLoader") + assertThat(loader).isNotNull + } + + @Test + fun getDataLoaders() { + val newLoaderA = DataLoaderFactory.newDataLoader(dataLoaderA) + dgsDataLoaderRegistry.register("exampleLoaderA", newLoaderA) + + val newLoaderB = DataLoaderFactory.newDataLoader(dataLoaderB) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + newLoaderB, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + assertThat(dgsDataLoaderRegistry.dataLoaders.size).isEqualTo(2) + val registeredLoaderA = dgsDataLoaderRegistry.getDataLoader("exampleLoaderA") + assertThat(registeredLoaderA).isNotNull + val registeredLoaderB = dgsDataLoaderRegistry.getDataLoader("exampleLoaderB") + assertThat(registeredLoaderB).isNotNull + } + + @Test + fun getDataLoadersAsMap() { + val newLoaderA = DataLoaderFactory.newDataLoader(dataLoaderA) + dgsDataLoaderRegistry.register("exampleLoaderA", newLoaderA) + + val newLoaderB = DataLoaderFactory.newDataLoader(dataLoaderB) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + newLoaderB, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + + assertThat(dgsDataLoaderRegistry.dataLoadersMap.size).isEqualTo(2) + val registeredLoaderA = dgsDataLoaderRegistry.dataLoadersMap["exampleLoaderA"] + assertThat(registeredLoaderA).isNotNull + val registeredLoaderB = dgsDataLoaderRegistry.dataLoadersMap["exampleLoaderB"] + assertThat(registeredLoaderB).isNotNull + } + + @Test + fun dispatchAll() { + every { mockDataLoaderB.dispatchDepth() } returns 1 + every { mockDataLoaderB.dispatch() } returns CompletableFuture.completedFuture(emptyList()) + every { mockDataLoaderA.dispatch() } returns CompletableFuture.completedFuture(emptyList()) + + dgsDataLoaderRegistry.register("exampleLoaderA", mockDataLoaderA) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + mockDataLoaderB, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + dgsDataLoaderRegistry.dispatchAll() + } + + @Test + fun dispatchDepth() { + every { mockDataLoaderA.dispatchDepth() } returns 2 + every { mockDataLoaderB.dispatchDepth() } returns 1 + + dgsDataLoaderRegistry.register("exampleLoaderA", mockDataLoaderA) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + mockDataLoaderB, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + assertThat(dgsDataLoaderRegistry.dispatchDepth()).isEqualTo(3) + } + + @Test + fun dispatchAllWithCount() { + every { mockDataLoaderB.dispatchDepth() } returns 3 + every { mockDataLoaderA.dispatchWithCounts().keysCount } returns 4 + every { mockDataLoaderB.dispatchWithCounts().keysCount } returns 3 + + dgsDataLoaderRegistry.register("exampleLoaderA", mockDataLoaderA) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + mockDataLoaderB, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + assertThat(dgsDataLoaderRegistry.dispatchAllWithCount()).isEqualTo(7) + } + + @Test + fun getStatistics() { + every { mockDataLoaderA.statistics } returns Statistics(1, 1, 1, 1, 1, 1) + every { mockDataLoaderB.statistics } returns Statistics(2, 2, 2, 2, 2, 2) + + dgsDataLoaderRegistry.register("exampleLoaderA", mockDataLoaderA) + dgsDataLoaderRegistry.registerWithDispatchPredicate( + "exampleLoaderB", + mockDataLoaderB, + DispatchPredicate.dispatchIfDepthGreaterThan(1) + ) + assertThat(dgsDataLoaderRegistry.statistics).isNotNull + assertThat(dgsDataLoaderRegistry.statistics.loadCount).isEqualTo(3) + assertThat(dgsDataLoaderRegistry.statistics.loadErrorCount).isEqualTo(3) + assertThat(dgsDataLoaderRegistry.statistics.batchInvokeCount).isEqualTo(3) + assertThat(dgsDataLoaderRegistry.statistics.batchLoadCount).isEqualTo(3) + assertThat(dgsDataLoaderRegistry.statistics.batchLoadExceptionCount).isEqualTo(3) + assertThat(dgsDataLoaderRegistry.statistics.cacheHitCount).isEqualTo(3) + } + + @DgsDataLoader(name = "exampleLoaderA") + class ExampleDataLoaderA : BatchLoader { + override fun load(keys: List): CompletionStage> { + return CompletableFuture.completedFuture(listOf("A", "B", "C")) + } + } + + @DgsDataLoader(name = "exampleLoaderB") + class ExampleDataLoaderB : BatchLoader { + override fun load(keys: List): CompletionStage> { + return CompletableFuture.completedFuture(listOf("A", "B", "C")) + } + } +}