Skip to content

Commit

Permalink
KSP: Fix interceptors with parameter defaults (#11103)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstepanov authored Aug 22, 2024
1 parent d598010 commit dc3cbf8
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import io.micronaut.inject.ast.FieldElement;
import io.micronaut.inject.ast.MethodElement;
import io.micronaut.inject.ast.ParameterElement;
import io.micronaut.inject.ast.PrimitiveElement;
import io.micronaut.inject.ast.TypedElement;
import io.micronaut.inject.configuration.ConfigurationMetadataBuilder;
import io.micronaut.inject.processing.JavaModelUtils;
Expand All @@ -62,6 +63,7 @@
import io.micronaut.inject.writer.ExecutableMethodsDefinitionWriter;
import io.micronaut.inject.writer.OriginatingElements;
import io.micronaut.inject.writer.ProxyingBeanDefinitionVisitor;
import io.micronaut.inject.writer.WriterUtils;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
Expand Down Expand Up @@ -211,6 +213,8 @@ public class AopProxyWriter extends AbstractClassFileWriter implements ProxyingB
private boolean constructorRequiresReflection;
private MethodElement declaredConstructor;
private MethodElement newConstructor;
private String newConstructorSignature;
private List<Map.Entry<ParameterElement, Integer>> superConstructorParametersBinding;
private ParameterElement qualifierParameter;
private ParameterElement interceptorsListParameter;
private VisitorContext visitorContext;
Expand Down Expand Up @@ -491,15 +495,38 @@ private void initConstructor(MethodElement constructor) {
this.interceptorsListParameter = ParameterElement.of(interceptorList, INTERCEPTORS_PARAMETER);
ParameterElement interceptorRegistryParameter = ParameterElement.of(ClassElement.of(InterceptorRegistry.class), "$interceptorRegistry");
ClassElement proxyClass = ClassElement.of(proxyType.getClassName());

superConstructorParametersBinding = new ArrayList<>();
ParameterElement[] constructorParameters = constructor.getParameters();
List<ParameterElement> newConstructorParameters = new ArrayList<>(constructorParameters.length + 5);
newConstructorParameters.addAll(Arrays.asList(constructorParameters));
newConstructorParameters.add(ParameterElement.of(BeanResolutionContext.class, "$beanResolutionContext"));
newConstructorParameters.add(ParameterElement.of(BeanContext.class, "$beanContext"));
int superConstructorParameterIndex = 0;
for (ParameterElement newConstructorParameter : newConstructorParameters) {
superConstructorParametersBinding.add(Map.entry(newConstructorParameter, superConstructorParameterIndex++));
}

ParameterElement beanResolutionContext = ParameterElement.of(BeanResolutionContext.class, "$beanResolutionContext");
newConstructorParameters.add(beanResolutionContext);
ParameterElement beanContext = ParameterElement.of(BeanContext.class, "$beanContext");
newConstructorParameters.add(beanContext);
newConstructorParameters.add(qualifierParameter);
newConstructorParameters.add(interceptorsListParameter);
newConstructorParameters.add(interceptorRegistryParameter);
superConstructorParameterIndex += 5; // Skip internal parameters
if (WriterUtils.hasKotlinDefaultsParameters(List.of(constructorParameters))) {
List<ParameterElement> realNewConstructorParameters = new ArrayList<>(newConstructorParameters);
int count = WriterUtils.calculateNumberOfKotlinDefaultsMasks(List.of(constructorParameters));
for (int j = 0; j < count; j++) {
ParameterElement mask = ParameterElement.of(PrimitiveElement.INT, "mask" + j);
realNewConstructorParameters.add(mask);
superConstructorParametersBinding.add(Map.entry(mask, superConstructorParameterIndex++));
}
ParameterElement marker = ParameterElement.of(ClassElement.of("kotlin.jvm.internal.DefaultConstructorMarker"), "marker");
realNewConstructorParameters.add(marker);
superConstructorParametersBinding.add(Map.entry(marker, superConstructorParameterIndex));
this.newConstructorSignature = getConstructorDescriptor(realNewConstructorParameters);
} else {
this.newConstructorSignature = getConstructorDescriptor(newConstructorParameters);
}
this.newConstructor = MethodElement.of(
proxyClass,
constructor.getAnnotationMetadata(),
Expand All @@ -508,11 +535,11 @@ private void initConstructor(MethodElement constructor) {
"<init>",
newConstructorParameters.toArray(ZERO_PARAMETER_ELEMENTS)
);
this.beanResolutionContextArgumentIndex = constructorParameters.length;
this.beanContextArgumentIndex = constructorParameters.length + 1;
this.qualifierIndex = constructorParameters.length + 2;
this.interceptorsListArgumentIndex = constructorParameters.length + 3;
this.interceptorRegistryArgumentIndex = constructorParameters.length + 4;
this.beanResolutionContextArgumentIndex = newConstructorParameters.indexOf(beanResolutionContext);
this.beanContextArgumentIndex = newConstructorParameters.indexOf(beanContext);
this.qualifierIndex = newConstructorParameters.indexOf(qualifierParameter);
this.interceptorsListArgumentIndex = newConstructorParameters.indexOf(interceptorsListParameter);
this.interceptorRegistryArgumentIndex = newConstructorParameters.indexOf(interceptorRegistryParameter);
}

@NonNull
Expand Down Expand Up @@ -745,28 +772,30 @@ public void visitBeanDefinitionEnd() {
});
qualifierParameter.annotate(AnnotationUtil.NULLABLE);

String constructorDescriptor = getConstructorDescriptor(Arrays.asList(newConstructor.getParameters()));
ClassWriter proxyClassWriter = this.classWriter;
this.constructorWriter = proxyClassWriter.visitMethod(
ACC_PUBLIC,
CONSTRUCTOR_NAME,
constructorDescriptor,
newConstructorSignature,
null,
null);

this.constructorGenerator = new GeneratorAdapter(constructorWriter, ACC_PUBLIC, CONSTRUCTOR_NAME, constructorDescriptor);
this.constructorGenerator = new GeneratorAdapter(constructorWriter, ACC_PUBLIC, CONSTRUCTOR_NAME, newConstructorSignature);
GeneratorAdapter proxyConstructorGenerator = this.constructorGenerator;

proxyConstructorGenerator.loadThis();
if (isInterface) {
proxyConstructorGenerator.invokeConstructor(TYPE_OBJECT, METHOD_DEFAULT_CONSTRUCTOR);
} else {
ParameterElement[] existingArguments = declaredConstructor.getParameters();
for (int i = 0; i < existingArguments.length; i++) {
proxyConstructorGenerator.loadArg(i);
List<ParameterElement> arguments = new ArrayList<>();
for (Map.Entry<ParameterElement, Integer> e : superConstructorParametersBinding) {
proxyConstructorGenerator.loadArg(e.getValue());
arguments.add(e.getKey());
}
String superConstructorDescriptor = getConstructorDescriptor(Arrays.asList(existingArguments));
proxyConstructorGenerator.invokeConstructor(getTypeReferenceForName(targetClassFullName), new Method(CONSTRUCTOR_NAME, superConstructorDescriptor));
proxyConstructorGenerator.invokeConstructor(
getTypeReferenceForName(targetClassFullName),
new Method(CONSTRUCTOR_NAME, getConstructorDescriptor(arguments))
);
}

proxyBeanDefinitionWriter.visitBeanDefinitionConstructor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3674,8 +3674,7 @@ private void visitBuildMethodDefinition(MethodElement constructor, boolean requi
final int parametersIndex = createConstructorParameterArray(parameters, buildMethodVisitor);
invokeConstructorChain(buildMethodVisitor, constructorIndex, parametersIndex, parameters);
} else {
boolean isKotlin = constructor.getClass().getSimpleName().startsWith("Kotlin");
if (isKotlin) {
if (WriterUtils.hasKotlinDefaultsParameters(parameters)) {
Map<Integer, Integer> checksLocals = new HashMap<>();
Map<Integer, Integer> valuesLocals = new HashMap<>();
WriterUtils.invokeBeanConstructor(buildMethodVisitor, constructor, requiresReflection, true, (index, parameter) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@
public final class WriterUtils {
private static final String METHOD_NAME_INSTANTIATE = "instantiate";

/**
* The number of Kotlin defaults masks.
* @param parameters The parameters
* @return The number if masks
* @since 4.6.2
*/
public static int calculateNumberOfKotlinDefaultsMasks(List<ParameterElement> parameters) {
return (int) Math.ceil(parameters.size() / 32.0);
}

/**
* Checks if parameter include Kotlin defaults.
* @param arguments The arguments
* @return true if include
* @since 4.6.2
*/
public static boolean hasKotlinDefaultsParameters(List<ParameterElement> arguments) {
return arguments.stream().anyMatch(p -> p instanceof KotlinParameterElement kp && kp.hasDefault());
}

public static void invokeBeanConstructor(GeneratorAdapter writer,
MethodElement constructor,
boolean allowKotlinDefaults,
Expand All @@ -80,7 +100,7 @@ public static void invokeBeanConstructor(GeneratorAdapter writer,
Collection<Type> argumentTypes = constructorArguments.stream().map(pe ->
JavaModelUtils.getTypeReference(pe.getType())
).toList();
boolean isKotlinDefault = allowKotlinDefaults && constructorArguments.stream().anyMatch(p -> p instanceof KotlinParameterElement kp && kp.hasDefault());
boolean isKotlinDefault = allowKotlinDefaults && hasKotlinDefaultsParameters(constructorArguments);

int[] masksLocal = null;
if (isKotlinDefault) {
Expand Down Expand Up @@ -247,7 +267,7 @@ public static int[] computeKotlinDefaultsMask(GeneratorAdapter writer,
@Nullable
BiFunction<Integer, ParameterElement, Boolean> argumentValueIsPresentPusher,
List<ParameterElement> parameters) {
int numberOfMasks = (int) Math.ceil(parameters.size() / 32.0);
int numberOfMasks = calculateNumberOfKotlinDefaultsMasks(parameters);
int[] masksLocal = new int[numberOfMasks];
for (int i = 0; i < numberOfMasks; i++) {
int maskLocal = writer.newLocal(Type.INT_TYPE);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.micronaut.docs.server.defaults_intercepted

import io.micronaut.context.annotation.Requires
import io.micronaut.core.async.annotation.SingleResult
import io.micronaut.http.HttpResponse
import io.micronaut.http.MediaType
import io.micronaut.http.annotation.Body
import io.micronaut.http.annotation.Controller
import io.micronaut.http.annotation.Header
import io.micronaut.http.annotation.Post
import jakarta.validation.constraints.NotBlank
import jakarta.validation.constraints.NotNull
import jakarta.validation.constraints.Size
import org.reactivestreams.Publisher
import reactor.core.publisher.Flux
import spock.lang.Specification
import java.time.OffsetDateTime
import java.time.ZoneId

@Requires(property = "spec.name", value = "defaults-intercepted")
// tag::class[]
@Controller("/defaults-intercepted")
open class DefaultsInterceptedController(private val timeProvider: (ZoneId) -> OffsetDateTime = OffsetDateTime::now) {
// end::class[]

// tag::echo[]
@Post(value = "/echo", consumes = [MediaType.TEXT_PLAIN]) // <1>
@NotBlank
open fun echo(@Size(max = 1024) @NotNull @Body text: String, @Header("MYHEADER") someHeader : String = "THEDEFAULT"): String { // <2>
return someHeader // <3>
}
// end::echo[]

// tag::echoReactive[]
@Post(value = "/echo-publisher", consumes = [MediaType.TEXT_PLAIN]) // <1>
@SingleResult
open fun echoFlow(@Body text: Publisher<String>, @NotNull @Header("MYHEADER") someHeader : String = "THEDEFAULT"): Publisher<HttpResponse<String>> { //<2>
return Flux.from(text)
.collect({ StringBuffer() }, { obj, str -> obj.append(str) }) // <3>
.map { HttpResponse.ok(someHeader) }
}
// end::echoReactive[]

// tag::endclass[]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.micronaut.docs.server.defaults_intercepted

import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import io.micronaut.context.ApplicationContext
import io.micronaut.http.HttpRequest
import io.micronaut.http.MediaType
import io.micronaut.http.client.HttpClient
import io.micronaut.runtime.server.EmbeddedServer

class DefaultsInterceptedControllerSpec : StringSpec() {

val embeddedServer = autoClose(
ApplicationContext.run(EmbeddedServer::class.java, mapOf("spec.name" to "defaults-intercepted"))
)

val client = autoClose(
embeddedServer.applicationContext.createBean(HttpClient::class.java, embeddedServer.getURL())
)

init {
"test echo response"() {
val response1 = client.toBlocking().retrieve(
HttpRequest.POST("/defaults-intercepted/echo", "My Text")
.header("MYHEADER", "abc123")
.contentType(MediaType.TEXT_PLAIN_TYPE), String::class.java
)

response1 shouldBe "abc123"

val response2 = client.toBlocking().retrieve(
HttpRequest.POST("/defaults-intercepted/echo", "My Text")
.contentType(MediaType.TEXT_PLAIN_TYPE), String::class.java
)

response2 shouldBe "THEDEFAULT"
}

"test echo reactive response"() {
val response1 = client.toBlocking().retrieve(
HttpRequest.POST("/defaults-intercepted/echo-publisher", "My Text")
.header("MYHEADER", "abc123")
.contentType(MediaType.TEXT_PLAIN_TYPE), String::class.java
)

response1 shouldBe "abc123"

val response2 = client.toBlocking().retrieve(
HttpRequest.POST("/defaults-intercepted/echo-publisher", "My Text")
.contentType(MediaType.TEXT_PLAIN_TYPE), String::class.java
)

response2 shouldBe "THEDEFAULT"
}
}
}

0 comments on commit dc3cbf8

Please sign in to comment.