Skip to content

Commit

Permalink
Sateful schema compiler (#216)
Browse files Browse the repository at this point in the history
Allow schema files to depend on recompiled records
  • Loading branch information
RustedBones authored Nov 12, 2024
1 parent 7c7d142 commit 8c1038d
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 1,088 deletions.
120 changes: 62 additions & 58 deletions bridge/src/main/java/com/github/sbt/avro/AvroCompilerBridge.java
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
package com.github.sbt.avro;

import org.apache.avro.Schema;
import org.apache.avro.specific.SpecificRecord;

import org.apache.avro.specific.SpecificData;
import org.apache.avro.Protocol;
import org.apache.avro.compiler.idl.Idl;
import org.apache.avro.compiler.specific.SpecificCompiler;
import org.apache.avro.compiler.specific.SpecificCompiler.FieldVisibility;
import org.apache.avro.generic.GenericData.StringType;

import java.io.File;
import java.util.HashSet;
import java.util.Set;
import java.util.*;

public class AvroCompilerBridge implements AvroCompiler {

private static final AvroVersion AVRO_1_9_0 = new AvroVersion(1, 9, 0);
private static final AvroVersion AVRO_1_10_0 = new AvroVersion(1, 10, 0);

private final AvroVersion avroVersion = AvroVersion.getRuntimeVersion();
private final AvroVersion avroVersion;
protected AvscFilesParser parser;

private StringType stringType;
private FieldVisibility fieldVisibility;
private boolean useNamespace;
private boolean enableDecimalLogicalType;
private boolean createSetters;
private boolean optionalGetters;
protected StringType stringType;
protected FieldVisibility fieldVisibility;
protected boolean useNamespace;
protected boolean enableDecimalLogicalType;
protected boolean createSetters;
protected boolean optionalGetters;

protected Schema.Parser createParser() {
return new Schema.Parser();
public AvroCompilerBridge() {
avroVersion = AvroVersion.getRuntimeVersion();
parser = new AvscFilesParser();
}

@Override
Expand Down Expand Up @@ -61,12 +61,9 @@ public void setOptionalGetters(boolean optionalGetters) {
this.optionalGetters = optionalGetters;
}

@Override
public void recompile(Class<?>[] records, File target) throws Exception {
AvscFilesCompiler compiler = new AvscFilesCompiler(this::createParser);
protected void configureCompiler(SpecificCompiler compiler) {
compiler.setStringType(stringType);
compiler.setFieldVisibility(fieldVisibility);
compiler.setUseNamespace(useNamespace);
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType);
compiler.setCreateSetters(createSetters);
if (avroVersion.compareTo(AVRO_1_9_0) >= 0) {
Expand All @@ -75,14 +72,20 @@ public void recompile(Class<?>[] records, File target) throws Exception {
if (avroVersion.compareTo(AVRO_1_10_0) >= 0) {
compiler.setOptionalGettersForNullableFieldsOnly(optionalGetters);
}
compiler.setTemplateDirectory("/org/apache/avro/compiler/specific/templates/java/classic/");
}

Set<Class<? extends SpecificRecord>> classes = new HashSet<>();
@Override
public void recompile(Class<?>[] records, File target) throws Exception {
List<Schema> schemas = new ArrayList<>(records.length);
for (Class<?> record : records) {
System.out.println("Recompiling Avro record: " + record.getName());
classes.add((Class<? extends SpecificRecord>) record);
Schema schema = SpecificData.get().getSchema(record);
schemas.add(schema);
SpecificCompiler compiler = new SpecificCompiler(schema);
configureCompiler(compiler);
compiler.compileToDestination(null, target);
}
compiler.compileClasses(classes, target);
parser.addTypes(schemas);
}

@Override
Expand All @@ -92,42 +95,32 @@ public void compileIdls(File[] idls, File target) throws Exception {
Idl parser = new Idl(idl);
Protocol protocol = parser.CompilationUnit();
SpecificCompiler compiler = new SpecificCompiler(protocol);
compiler.setStringType(stringType);
compiler.setFieldVisibility(fieldVisibility);
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType);
compiler.setCreateSetters(createSetters);
if (avroVersion.compareTo(AVRO_1_9_0) >= 0) {
compiler.setGettersReturnOptional(optionalGetters);
}
if (avroVersion.compareTo(AVRO_1_10_0) >= 0) {
compiler.setOptionalGettersForNullableFieldsOnly(optionalGetters);
}
compiler.compileToDestination(null, target);
configureCompiler(compiler);
compiler.compileToDestination(idl, target);
}
}

@Override
public void compileAvscs(AvroFileRef[] avscs, File target) throws Exception {
AvscFilesCompiler compiler = new AvscFilesCompiler(this::createParser);
compiler.setStringType(stringType);
compiler.setFieldVisibility(fieldVisibility);
compiler.setUseNamespace(useNamespace);
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType);
compiler.setCreateSetters(createSetters);
if (avroVersion.compareTo(AVRO_1_9_0) >= 0) {
compiler.setGettersReturnOptional(optionalGetters);
}
if (avroVersion.compareTo(AVRO_1_10_0) >= 0) {
compiler.setOptionalGettersForNullableFieldsOnly(optionalGetters);
}
compiler.setTemplateDirectory("/org/apache/avro/compiler/specific/templates/java/classic/");

Set<AvroFileRef> files = new HashSet<>();
List<AvroFileRef> files = new ArrayList<>(avscs.length);
for (AvroFileRef ref : avscs) {
System.out.println("Compiling Avro schema: " + ref.getFile());
files.add(ref);
}
compiler.compileFiles(files, target);
Map<AvroFileRef, Schema> schemas = parser.parseFiles(files);
if (useNamespace) {
for (Map.Entry<AvroFileRef, Schema> s: schemas.entrySet()) {
validateParsedSchema(s.getKey(), s.getValue());
}
}

for (Map.Entry<AvroFileRef, Schema> entry: schemas.entrySet()) {
File file = entry.getKey().getFile();
Schema schema = entry.getValue();
SpecificCompiler compiler = new SpecificCompiler(schema);
configureCompiler(compiler);
compiler.compileToDestination(file, target);
}
}

@Override
Expand All @@ -136,17 +129,28 @@ public void compileAvprs(File[] avprs, File target) throws Exception {
System.out.println("Compiling Avro protocol: " + avpr);
Protocol protocol = Protocol.parse(avpr);
SpecificCompiler compiler = new SpecificCompiler(protocol);
compiler.setStringType(stringType);
compiler.setFieldVisibility(fieldVisibility);
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType);
compiler.setCreateSetters(createSetters);
if (avroVersion.compareTo(AVRO_1_9_0) >= 0) {
compiler.setGettersReturnOptional(optionalGetters);
}
if (avroVersion.compareTo(AVRO_1_10_0) >= 0) {
compiler.setOptionalGettersForNullableFieldsOnly(optionalGetters);
}
configureCompiler(compiler);
compiler.compileToDestination(null, target);
}
}

private void validateParsedSchema(AvroFileRef src, Schema schema) {
if (useNamespace) {
if (schema.getType() != Schema.Type.RECORD && schema.getType() != Schema.Type.ENUM) {
throw new SchemaGenerationException(String.format(
"Error compiling schema file %s. "
+ "Only one root RECORD or ENUM type is allowed per file.",
src
));
} else if (!src.pathToClassName().equals(schema.getFullName())) {
throw new SchemaGenerationException(String.format(
"Error compiling schema file %s. "
+ "File class name %s does not match record class name %s",
src,
src.pathToClassName(),
schema.getFullName()
));
}
}
}
}
Loading

0 comments on commit 8c1038d

Please sign in to comment.