diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateAnnotationTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateAnnotationTest.java index 6629eea4490..9b174dd6307 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateAnnotationTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateAnnotationTest.java @@ -33,7 +33,7 @@ class JavaTemplateAnnotationTest implements RewriteTest { @DocumentExample @Test - void replaceAnnotation() { + void replaceClassAnnotation() { rewriteRun( spec -> spec.expectedCyclesThatMakeChanges(2) .recipe(toRecipe(() -> new JavaVisitor<>() { @@ -49,7 +49,7 @@ public J visitAnnotation(J.Annotation annotation, ExecutionContext executionCont @Deprecated(since = "1.0", forRemoval = true) class A { } - """, + """, """ @Deprecated(since = "2.0", forRemoval = true) class A { @@ -59,6 +59,37 @@ class A { ); } + @Test + void replaceNestedClassAnnotation() { + rewriteRun( + spec -> spec.expectedCyclesThatMakeChanges(2) + .recipe(toRecipe(() -> new JavaVisitor<>() { + @Override + public J visitAnnotation(J.Annotation annotation, ExecutionContext executionContext) { + return JavaTemplate.apply("@Deprecated(since = \"#{}\", forRemoval = true)", + getCursor(), annotation.getCoordinates().replace(), "2.0"); + } + } + )), + java( + """ + class A { + @Deprecated(since = "1.0", forRemoval = true) + class B { + } + } + """, + """ + class A { + @Deprecated(since = "2.0", forRemoval = true) + class B { + } + } + """ + ) + ); + } + @ExpectedToFail @Test void replaceAnnotation2() { diff --git a/rewrite-java/src/main/java/org/openrewrite/java/internal/template/AnnotationTemplateGenerator.java b/rewrite-java/src/main/java/org/openrewrite/java/internal/template/AnnotationTemplateGenerator.java index 45f9083bb2a..7470f0ecaef 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/internal/template/AnnotationTemplateGenerator.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/internal/template/AnnotationTemplateGenerator.java @@ -22,6 +22,7 @@ import org.openrewrite.Cursor; import org.openrewrite.internal.ListUtils; import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.service.AnnotationService; import org.openrewrite.java.tree.*; import java.util.ArrayList; @@ -117,7 +118,7 @@ public J.Annotation visitAnnotation(J.Annotation annotation, Integer integer) { private void template(Cursor cursor, J prior, StringBuilder before, StringBuilder after, Set templated) { templated.add(cursor.getValue()); - J j = cursor.getValue(); + final J j = cursor.getValue(); if (j instanceof JavaSourceFile) { JavaSourceFile cu = (JavaSourceFile) j; for (J.Import anImport : cu.getImports()) { @@ -132,16 +133,14 @@ private void template(Cursor cursor, J prior, StringBuilder before, StringBuilde } List classes = cu.getClasses(); if (!classes.get(classes.size() - 1).getName().getSimpleName().equals("$Placeholder")) { - after.append("@interface $Placeholder {}"); + after.append("\n@interface $Placeholder {}"); } return; - } - if (j instanceof J.Block) { + } else if (j instanceof J.ClassDeclaration) { + classDeclaration(before, after, (J.ClassDeclaration) j, templated, cursor, prior); + } else if (j instanceof J.Block) { J parent = next(cursor).getValue(); - if (parent instanceof J.ClassDeclaration) { - classDeclaration(before, (J.ClassDeclaration) parent, templated, cursor); - after.append('}'); - } else if (parent instanceof J.MethodDeclaration) { + if (parent instanceof J.MethodDeclaration) { J.MethodDeclaration m = (J.MethodDeclaration) parent; // variable declarations up to the point of insertion @@ -200,29 +199,46 @@ private void template(Cursor cursor, J prior, StringBuilder before, StringBuilde template(next(cursor), j, before, after, templated); } - private void classDeclaration(StringBuilder before, J.ClassDeclaration parent, Set templated, Cursor cursor) { + private void classDeclaration(StringBuilder before, StringBuilder after, J.ClassDeclaration parent, Set templated, Cursor cursor, J prior) { J.ClassDeclaration c = parent; - for (Statement statement : c.getBody().getStatements()) { - if (templated.contains(statement)) { - continue; - } + boolean annotated = isAnnotated(cursor, prior); + if (!annotated) { + for (Statement statement : c.getBody().getStatements()) { + if (templated.contains(statement)) { + continue; + } - if (statement instanceof J.VariableDeclarations) { - J.VariableDeclarations v = (J.VariableDeclarations) statement; - if (v.hasModifier(J.Modifier.Type.Final) && v.hasModifier(J.Modifier.Type.Static)) { - before.insert(0, variable((J.VariableDeclarations) statement, cursor) + ";\n"); + if (statement instanceof J.VariableDeclarations) { + J.VariableDeclarations v = (J.VariableDeclarations) statement; + if (v.hasModifier(J.Modifier.Type.Final) && v.hasModifier(J.Modifier.Type.Static)) { + before.insert(0, variable((J.VariableDeclarations) statement, cursor) + ";\n"); + } + } else if (statement instanceof J.ClassDeclaration) { + // this is a sibling class. we need declarations for all variables and methods. + // setting prior to null will cause them all to be written. + before.insert(0, '}'); + classDeclaration(before, after, (J.ClassDeclaration) statement, templated, cursor, prior); } - } else if (statement instanceof J.ClassDeclaration) { - // this is a sibling class. we need declarations for all variables and methods. - // setting prior to null will cause them all to be written. - before.insert(0, '}'); - classDeclaration(before, (J.ClassDeclaration) statement, templated, cursor); } } c = c.withBody(J.Block.createEmptyBlock()).withLeadingAnnotations(null).withPrefix(Space.EMPTY); String printed = c.printTrimmed(cursor); int braceIndex = printed.lastIndexOf('{'); - before.insert(0, braceIndex == -1 ? printed + '{' : printed.substring(0, braceIndex + 1)); + if (annotated) { + after.append(braceIndex == -1 ? printed + '{' : printed.substring(0, braceIndex + 1)); + } else { + before.insert(0, braceIndex == -1 ? printed + '{' : printed.substring(0, braceIndex + 1)); + } + after.append('}'); + } + + private static boolean isAnnotated(Cursor cursor, J maybeAnnotation) { + if (!(maybeAnnotation instanceof J.Annotation)) { + return false; + } + Cursor sourceFileCursor = cursor.dropParentUntil(is -> is instanceof JavaSourceFile); + AnnotationService annotationService = sourceFileCursor.getValue().service(AnnotationService.class); + return annotationService.getAllAnnotations(cursor).contains(maybeAnnotation); } private String variable(J.VariableDeclarations variable, Cursor cursor) {