Skip to content

Commit

Permalink
Support JEP-440: Pattern matching for Records (#4925)
Browse files Browse the repository at this point in the history
* Implement handling of record pattern matching

---------

Co-authored-by: Tim te Beek <[email protected]>
  • Loading branch information
Laurens-W and timtebeek authored Jan 23, 2025
1 parent d10d0ec commit c7a28a6
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -785,10 +785,23 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
type);
}

@Override
public J visitDeconstructionPattern(DeconstructionPatternTree node, Space fmt) {
JavaType type = typeMapping.type(node);
return new J.DeconstructionPattern(randomId(),
fmt,
Markers.EMPTY,
convert(node.getDeconstructor()),
JContainer.build(sourceBefore("("), convertAll(node.getNestedPatterns(), commaDelim, t -> sourceBefore(")")), Markers.EMPTY),
type);
}

private @Nullable J getNodePattern(@Nullable PatternTree pattern, JavaType type) {
if (pattern instanceof JCBindingPattern b) {
return new J.Identifier(randomId(), sourceBefore(b.getVariable().getName().toString()), Markers.EMPTY, emptyList(), b.getVariable().getName().toString(),
type, typeMapping.variableType(b.var.sym));
} else if (pattern instanceof DeconstructionPatternTree r) {
return visitDeconstructionPattern(r, whitespace());
} else {
if (pattern == null) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ private JavaType.FullyQualified classType(Type.ClassType classType, String signa
if (sym.members_field != null) {
for (Symbol elem : sym.members_field.getSymbols()) {
if (elem instanceof Symbol.VarSymbol &&
(elem.flags_field & (Flags.SYNTHETIC | Flags.BRIDGE | Flags.HYPOTHETICAL |
Flags.GENERATEDCONSTR | Flags.ANONCONSTR)) == 0) {
(elem.flags_field & (Flags.SYNTHETIC | Flags.BRIDGE | Flags.HYPOTHETICAL |
Flags.GENERATEDCONSTR | Flags.ANONCONSTR)) == 0) {
if (fqn.equals("java.lang.String") && elem.name.toString().equals("serialPersistentFields")) {
// there is a "serialPersistentFields" member within the String class which is used in normal Java
// serialization to customize how the String field is serialized. This field is tripping up Jackson
Expand All @@ -305,7 +305,7 @@ private JavaType.FullyQualified classType(Type.ClassType classType, String signa
}
fields.add(variableType(elem, clazz));
} else if (elem instanceof Symbol.MethodSymbol &&
(elem.flags_field & (Flags.SYNTHETIC | Flags.BRIDGE | Flags.HYPOTHETICAL | Flags.ANONCONSTR)) == 0) {
(elem.flags_field & (Flags.SYNTHETIC | Flags.BRIDGE | Flags.HYPOTHETICAL | Flags.ANONCONSTR)) == 0) {
if (methods == null) {
methods = new ArrayList<>();
}
Expand Down Expand Up @@ -376,6 +376,8 @@ private JavaType.FullyQualified.Kind getKind(Symbol.ClassSymbol sym) {
return variableType(((JCTree.JCVariableDecl) tree).sym);
} else if (tree instanceof JCTree.JCAnnotatedType && ((JCTree.JCAnnotatedType) tree).getUnderlyingType() instanceof JCTree.JCArrayTypeTree) {
return annotatedArray((JCTree.JCAnnotatedType) tree);
} else if (tree instanceof JCTree.JCRecordPattern) {
symbol = ((JCTree.JCRecordPattern) tree).record;
}

return type(((JCTree) tree).type, symbol);
Expand Down Expand Up @@ -424,7 +426,7 @@ public JavaType.Primitive primitive(TypeTag tag) {
}

private JavaType.@Nullable Variable variableType(@Nullable Symbol symbol,
JavaType.@Nullable FullyQualified owner) {
JavaType.@Nullable FullyQualified owner) {
if (!(symbol instanceof Symbol.VarSymbol)) {
return null;
}
Expand Down Expand Up @@ -598,23 +600,23 @@ public JavaType.Primitive primitive(TypeTag tag) {
}
}
List<String> defaultValues = null;
if(methodSymbol.getDefaultValue() != null) {
if(methodSymbol.getDefaultValue() instanceof Attribute.Array) {
if (methodSymbol.getDefaultValue() != null) {
if (methodSymbol.getDefaultValue() instanceof Attribute.Array) {
defaultValues = ((Attribute.Array) methodSymbol.getDefaultValue()).getValue().stream()
.map(attr -> attr.getValue().toString())
.collect(Collectors.toList());
} else {
try {
defaultValues = Collections.singletonList(methodSymbol.getDefaultValue().getValue().toString());
} catch(UnsupportedOperationException e) {
} catch (UnsupportedOperationException e) {
// not all Attribute implementations define `getValue()`
}
}
}

List<String> declaredFormalTypeNames = null;
for (Symbol.TypeVariableSymbol typeParam : methodSymbol.getTypeParameters()) {
if(typeParam.owner == methodSymbol) {
if (typeParam.owner == methodSymbol) {
if (declaredFormalTypeNames == null) {
declaredFormalTypeNames = new ArrayList<>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,13 @@

import org.junit.jupiter.api.Test;
import org.openrewrite.java.MinimumJava21;
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;
import org.openrewrite.test.TypeValidation;

import static org.openrewrite.java.Assertions.java;

@MinimumJava21
class RecordPatternMatchingTest implements RewriteTest {

@Override
public void defaults(RecipeSpec spec) {
spec.typeValidationOptions(TypeValidation.all().unknown(false));
}

@Test
void shouldParseJava21PatternMatchForRecords() {
rewriteRun(
Expand All @@ -46,7 +39,8 @@ void printSum(Object obj) {
}
}
"""
));
)
);
}

@Test
Expand All @@ -68,7 +62,8 @@ void printColorOfUpperLeftPoint(Rectangle r) {
}
}
"""
));
)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,27 @@ void integerTester(Integer i) {
));
}

@Test
void shouldParseRecordPatternMatchingInSwitch() {
rewriteRun(
java(
//language=java
"""
class Test {
public interface Printable {}
record A(String A) implements Printable {}
record B(Integer B) implements Printable {}
void integerTester(Printable prt) {
switch (prt) {
case A(String a) -> System.out.println(a);
case B(Integer b) -> System.out.println(b);
default -> throw new IllegalStateException("Unexpected value: " + prt);
}
}
}
"""
));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public J.InstanceOf visitInstanceOf(J.InstanceOf instanceOf, P p) {
return (J.InstanceOf) super.visitInstanceOf(instanceOf, p);
}

@Override
public J.DeconstructionPattern visitDeconstructionPattern(J.DeconstructionPattern deconstructionPattern, P p) {
return (J.DeconstructionPattern) super.visitDeconstructionPattern(deconstructionPattern, p);
}

@Override
public J.IntersectionType visitIntersectionType(J.IntersectionType intersectionType, P p) {
return (J.IntersectionType) super.visitIntersectionType(intersectionType, p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,15 @@ public J visitInstanceOf(InstanceOf instanceOf, PrintOutputCapture<P> p) {
return instanceOf;
}

@Override
public J visitDeconstructionPattern(DeconstructionPattern deconstructionPattern, PrintOutputCapture<P> p) {
beforeSyntax(deconstructionPattern, Space.Location.DECONSTRUCTION_PATTERN_PREFIX, p);
visitAndCast(deconstructionPattern.getDeconstructor(), p);
visitContainer("(", deconstructionPattern.getPadding().getNested(), JContainer.Location.DECONSTRUCTION_PATTERN_NESTED, ",", ")", p);
afterSyntax(deconstructionPattern, p);
return deconstructionPattern;
}

@Override
public J visitIntersectionType(IntersectionType intersectionType, PrintOutputCapture<P> p) {
beforeSyntax(intersectionType, Space.Location.INTERSECTION_TYPE_PREFIX, p);
Expand Down
11 changes: 11 additions & 0 deletions rewrite-java/src/main/java/org/openrewrite/java/JavaVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,17 @@ public J visitInstanceOf(J.InstanceOf instanceOf, P p) {
return i;
}

public J visitDeconstructionPattern(J.DeconstructionPattern deconstructionPattern, P p) {
J.DeconstructionPattern d = deconstructionPattern;
d = d.withPrefix(visitSpace(d.getPrefix(), Space.Location.DECONSTRUCTION_PATTERN_PREFIX, p));
d = d.withMarkers(visitMarkers(d.getMarkers(), p));
d = d.withDeconstructor(visitAndCast(d.getDeconstructor(), p));
d = d.getPadding().withNested(visitContainer(d.getPadding().getNested(), JContainer.Location.DECONSTRUCTION_PATTERN_NESTED, p));
d = d.withType(visitType(d.getType(), p));
return d;

}

public J visitIntersectionType(J.IntersectionType intersectionType, P p) {
J.IntersectionType i = intersectionType;
i = i.withPrefix(visitSpace(i.getPrefix(), Space.Location.INTERSECTION_TYPE_PREFIX, p));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,25 @@ public J.InstanceOf visitInstanceOf(J.InstanceOf instanceOf, J j) {
return instanceOf;
}

@Override
public J.DeconstructionPattern visitDeconstructionPattern(J.DeconstructionPattern deconstructionPattern, J j) {
if (isEqual.get()) {
if (!(j instanceof J.DeconstructionPattern)) {
isEqual.set(false);
return deconstructionPattern;
}

J.DeconstructionPattern compareTo = (J.DeconstructionPattern) j;
if (!TypeUtils.isOfType(deconstructionPattern.getType(), compareTo.getType())) {
isEqual.set(false);
return deconstructionPattern;
}
visit(deconstructionPattern.getDeconstructor(), compareTo.getDeconstructor());
this.visitList(deconstructionPattern.getNested(), compareTo.getNested());
}
return deconstructionPattern;
}

@Override
public J.Label visitLabel(J.Label label, J j) {
if (isEqual.get()) {
Expand Down
81 changes: 81 additions & 0 deletions rewrite-java/src/main/java/org/openrewrite/java/tree/J.java
Original file line number Diff line number Diff line change
Expand Up @@ -3047,6 +3047,87 @@ public InstanceOf withExpr(JRightPadded<Expression> expression) {
}
}

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@EqualsAndHashCode(callSuper = false, onlyExplicitlyIncluded = true)
@RequiredArgsConstructor
@AllArgsConstructor(access = AccessLevel.PRIVATE)
final class DeconstructionPattern implements J, TypedTree {

@Nullable
@NonFinal
transient WeakReference<Padding> padding;

@With
@EqualsAndHashCode.Include
@Getter
UUID id;

@With
@Getter
Space prefix;

@With
@Getter
Markers markers;

@With
@Getter
Expression deconstructor;

JContainer<J> nested;

public List<J> getNested() {
return nested.getElements();
}

public DeconstructionPattern withNested(List<J> nested) {
return getPadding().withNested(JContainer.withElements(this.nested, nested));
}

@Getter
@With
JavaType type;

@Override
public <P> J acceptJava(JavaVisitor<P> v, P p) {
return v.visitDeconstructionPattern(this, p);
}

@Override
public String toString() {
return withPrefix(Space.EMPTY).printTrimmed(new JavaPrinter<>());
}

public Padding getPadding() {
Padding p;
if (this.padding == null) {
p = new Padding(this);
this.padding = new WeakReference<>(p);
} else {
p = this.padding.get();
if (p == null || p.t != this) {
p = new Padding(this);
this.padding = new WeakReference<>(p);
}
}
return p;
}

@RequiredArgsConstructor
public static class Padding {
private final DeconstructionPattern t;

public JContainer<J> getNested() {
return t.nested;
}

public DeconstructionPattern withNested(JContainer<J> nested) {
return t.nested == nested ? t : new DeconstructionPattern(t.id, t.prefix, t.markers, t.deconstructor, nested, t.type);
}
}

}

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@EqualsAndHashCode(callSuper = false, onlyExplicitlyIncluded = true)
@RequiredArgsConstructor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ public enum Location {
CASE(Space.Location.CASE, JRightPadded.Location.CASE),
CASE_EXPRESSION(Space.Location.CASE_EXPRESSION, JRightPadded.Location.CASE_EXPRESSION),
CASE_LABEL(Space.Location.CASE_LABEL, JRightPadded.Location.CASE_LABEL),
DECONSTRUCTION_PATTERN_NESTED(Space.Location.DECONSTRUCTION_PATTERN_NESTED, JRightPadded.Location.DECONSTRUCTION_PATTERN_NESTED),
IMPLEMENTS(Space.Location.IMPLEMENTS, JRightPadded.Location.IMPLEMENTS),
PERMITS(Space.Location.PERMITS, JRightPadded.Location.PERMITS),
LANGUAGE_EXTENSION(Space.Location.LANGUAGE_EXTENSION, JRightPadded.Location.LANGUAGE_EXTENSION),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public enum Location {
IF_ELSE(Space.Location.IF_ELSE_SUFFIX),
IF_THEN(Space.Location.IF_THEN_SUFFIX),
IMPLEMENTS(Space.Location.IMPLEMENTS_SUFFIX),
DECONSTRUCTION_PATTERN_NESTED(Space.Location.DECONSTRUCTION_PATTERN_NESTED_SUFFIX),
PERMITS(Space.Location.PERMITS_SUFFIX),
IMPORT(Space.Location.IMPORT_SUFFIX),
INSTANCEOF(Space.Location.INSTANCEOF_SUFFIX),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ public static Space format(String formatting, int beginIndex, int toIndex) {

@SuppressWarnings("ConstantConditions")
public static <J2 extends J> @Nullable List<JRightPadded<J2>> formatLastSuffix(@Nullable List<JRightPadded<J2>> trees,
Space suffix) {
Space suffix) {
if (trees == null) {
return null;
}
Expand Down Expand Up @@ -336,6 +336,9 @@ public enum Location {
COMPILATION_UNIT_PREFIX,
CONTINUE_PREFIX,
CONTROL_PARENTHESES_PREFIX,
DECONSTRUCTION_PATTERN_PREFIX,
DECONSTRUCTION_PATTERN_NESTED,
DECONSTRUCTION_PATTERN_NESTED_SUFFIX,
DIMENSION,
DIMENSION_PREFIX,
DIMENSION_SUFFIX,
Expand Down

0 comments on commit c7a28a6

Please sign in to comment.