diff --git a/src/main/java/io/aiven/kafka/connect/http/converter/RecordValueConverter.java b/src/main/java/io/aiven/kafka/connect/http/converter/RecordValueConverter.java index 0324f30..29a4425 100644 --- a/src/main/java/io/aiven/kafka/connect/http/converter/RecordValueConverter.java +++ b/src/main/java/io/aiven/kafka/connect/http/converter/RecordValueConverter.java @@ -16,23 +16,26 @@ package io.aiven.kafka.connect.http.converter; -import java.util.HashMap; -import java.util.LinkedHashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.connect.errors.DataException; import org.apache.kafka.connect.sink.SinkRecord; public class RecordValueConverter { - + private static final ConcurrentHashMap, Converter> RUNTIME_CLASS_TO_CONVERTER_CACHE = + new ConcurrentHashMap<>(); private final JsonRecordValueConverter jsonRecordValueConverter = new JsonRecordValueConverter(); private final Map, Converter> converters = Map.of( - String.class, record -> (String) record.value(), - HashMap.class, jsonRecordValueConverter, - LinkedHashMap.class, jsonRecordValueConverter, - Struct.class, jsonRecordValueConverter + String.class, record -> (String) record.value(), + Map.class, jsonRecordValueConverter, + Struct.class, jsonRecordValueConverter ); interface Converter { @@ -40,12 +43,71 @@ interface Converter { } public String convert(final SinkRecord record) { - if (!converters.containsKey(record.value().getClass())) { + final Converter converter = getConverter(record); + return converter.convert(record); + } + + private Converter getConverter(final SinkRecord record) { + return RUNTIME_CLASS_TO_CONVERTER_CACHE.computeIfAbsent(record.value().getClass(), clazz -> { + final boolean directlyConvertible = converters.containsKey(clazz); + final List> convertibleByImplementedTypes = getAllSerializableImplementedInterfaces(clazz); + validateConvertibility(clazz, directlyConvertible, convertibleByImplementedTypes); + + Class implementedClazz = clazz; + if (!directlyConvertible) { + implementedClazz = convertibleByImplementedTypes.get(0); + } + return converters.get(implementedClazz); + }); + } + + private List> getAllSerializableImplementedInterfaces(final Class recordClazz) { + // caching the computation since querying implemented interfaces is expensive. + // The size of the cache is unlimited, but I don't think it's a problem + // since the number of different record classes is limited. + return getAllInterfaces(recordClazz).stream() + .filter(converters::containsKey) + .collect(Collectors.toList()); + } + + public static Set> getAllInterfaces(final Class clazz) { + final Set> interfaces = new HashSet<>(); + + for (final Class implementation : clazz.getInterfaces()) { + interfaces.add(implementation); + interfaces.addAll(getAllInterfaces(implementation)); + } + + if (clazz.getSuperclass() != null) { + interfaces.addAll(getAllInterfaces(clazz.getSuperclass())); + } + + return interfaces; + } + + private static void validateConvertibility( + final Class recordClazz, + final boolean directlyConvertible, + final List> convertibleByImplementedTypes + ) { + final boolean isConvertibleType = directlyConvertible || !convertibleByImplementedTypes.isEmpty(); + + if (!isConvertibleType) { + throw new DataException( + String.format( + "Record value must be a String, a Schema Struct or implement " + + "`java.util.Map`, but %s is given", + recordClazz)); + } + if (!directlyConvertible && convertibleByImplementedTypes.size() > 1) { + final String implementedTypes = convertibleByImplementedTypes.stream().map(Class::getSimpleName) + .collect(Collectors.joining(", ", "[", "]")); throw new DataException( - "Record value must be String, Schema Struct, LinkedHashMap or HashMap," - + " but " + record.value().getClass() + " is given"); + String.format( + "Record value must be only one of String, Schema Struct or implement " + + "`java.util.Map`, but %s matches multiple types: %s", + recordClazz, implementedTypes)); } - return converters.get(record.value().getClass()).convert(record); } } diff --git a/src/test/java/io/aiven/kafka/connect/http/converter/RecordValueConverterTest.java b/src/test/java/io/aiven/kafka/connect/http/converter/RecordValueConverterTest.java index 859070a..ad036b0 100644 --- a/src/test/java/io/aiven/kafka/connect/http/converter/RecordValueConverterTest.java +++ b/src/test/java/io/aiven/kafka/connect/http/converter/RecordValueConverterTest.java @@ -16,10 +16,13 @@ package io.aiven.kafka.connect.http.converter; +import javax.swing.UIDefaults; + import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; + import org.apache.kafka.connect.data.Field; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; @@ -68,6 +71,22 @@ void convertStringRecord() { assertThat(recordValueConverter.convert(sinkRecord)).isEqualTo("some-str-value"); } + @Test + void convertWeirdMapRecord() { + final var recordSchema = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA); + + final UIDefaults value = new UIDefaults( + new String[] {"Font", "BeautifulFont"} + ); + + final var sinkRecord = new SinkRecord( + "some-topic", 0, + SchemaBuilder.string(), + "some-key", recordSchema, value, 1L); + + assertThat(recordValueConverter.convert(sinkRecord)).isEqualTo("{\"Font\":\"BeautifulFont\"}"); + } + @Test void convertHashMapRecord() { final var recordSchema = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA);