Skip to content

Commit

Permalink
Introduce custom compiler option
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Sep 23, 2024
1 parent ee6fa40 commit fb8ba6e
Show file tree
Hide file tree
Showing 15 changed files with 158 additions and 94 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ libraryDependencies += "org.apache.avro" % "avro" % avroCompilerVersion

| Name | Default | Description |
|:-------------------------------------------|:----------------------------------------------|:----------------------------------------------------------------------------------------|
| `avroSource` | `sourceDirectory` / `avro` | Source directory with `*.avsc`, `*.avdl` and `*.avpr` files. |
| `avroSources` | `sourceDirectory` / `avro` | Source directories with `*.avsc`, `*.avdl` and `*.avpr` files. |
| `avroSpecificRecords` | `Seq.empty` | List of avro generated classes to recompile with current avro version and settings. |
| `avroSchemaParserBuilder` | `DefaultSchemaParserBuilder.default()` | `.avsc` schema parser builder |
| `avroUnpackDependencies` / `includeFilter` | All avro specifications | Avro specification files from dependencies to unpack |
Expand Down
8 changes: 8 additions & 0 deletions api/src/main/java/com/github/sbt/avro/AvroCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
import java.io.File;

public interface AvroCompiler {

void setStringType(String stringType);
void setFieldVisibility(String fieldVisibility);
void setUseNamespace(boolean useNamespace);
void setEnableDecimalLogicalType(boolean enableDecimalLogicalType);
void setCreateSetters(boolean createSetters);
void setOptionalGetters(boolean optionalGetters);

void recompile(Class<?>[] records, File target) throws Exception;
void compileIdls(File[] idls, File target) throws Exception;
void compileAvscs(AvroFileRef[] avscs, File target) throws Exception;
Expand Down
66 changes: 38 additions & 28 deletions bridge/src/main/java/com/github/sbt/avro/AvroCompilerBridge.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import org.apache.avro.Schema;
import org.apache.avro.specific.SpecificRecord;
import xsbti.Logger;

import org.apache.avro.Protocol;
import org.apache.avro.compiler.idl.Idl;
Expand All @@ -16,39 +15,50 @@

public class AvroCompilerBridge implements AvroCompiler {

private final Logger logger;
private final StringType stringType;
private final FieldVisibility fieldVisibility;
private final boolean useNamespace;
private final boolean enableDecimalLogicalType;
private final boolean createSetters;
private final boolean optionalGetters;

public AvroCompilerBridge(
Logger logger,
String stringType,
String fieldVisibility,
boolean useNamespace,
boolean enableDecimalLogicalType,
boolean createSetters,
boolean optionalGetters
) {
this.logger = logger;
private StringType stringType;
private FieldVisibility fieldVisibility;
private boolean useNamespace;
private boolean enableDecimalLogicalType;
private boolean createSetters;
private boolean optionalGetters;

protected Schema.Parser createParser() {
return new Schema.Parser();
}

@Override
public void setStringType(String stringType) {
this.stringType = StringType.valueOf(stringType);
}

@Override
public void setFieldVisibility(String fieldVisibility) {
this.fieldVisibility = FieldVisibility.valueOf(fieldVisibility);
}

@Override
public void setUseNamespace(boolean useNamespace) {
this.useNamespace = useNamespace;
}

@Override
public void setEnableDecimalLogicalType(boolean enableDecimalLogicalType) {
this.enableDecimalLogicalType = enableDecimalLogicalType;
}

@Override
public void setCreateSetters(boolean createSetters) {
this.createSetters = createSetters;
this.optionalGetters = optionalGetters;
}

protected Schema.Parser createParser() {
return new Schema.Parser();
@Override
public void setOptionalGetters(boolean optionalGetters) {
this.optionalGetters = optionalGetters;
}

@Override
public void recompile(Class<?>[] records, File target) throws Exception {
AvscFilesCompiler compiler = new AvscFilesCompiler();
AvscFilesCompiler compiler = new AvscFilesCompiler(this::createParser);
compiler.setStringType(stringType);
compiler.setFieldVisibility(fieldVisibility);
compiler.setUseNamespace(useNamespace);
Expand All @@ -64,7 +74,7 @@ public void recompile(Class<?>[] records, File target) throws Exception {

Set<Class<? extends SpecificRecord>> classes = new HashSet<>();
for (Class<?> record : records) {
logger.info(() -> "Recompiling Avro record: " + record.getName());
System.out.println("Recompiling Avro record: " + record.getName());
classes.add((Class<? extends SpecificRecord>) record);
}
compiler.compileClasses(classes, target);
Expand All @@ -73,7 +83,7 @@ public void recompile(Class<?>[] records, File target) throws Exception {
@Override
public void compileIdls(File[] idls, File target) throws Exception {
for (File idl : idls) {
logger.info(() -> "Compiling Avro IDL " + idl);
System.out.println("Compiling Avro IDL " + idl);
Idl parser = new Idl(idl);
Protocol protocol = parser.CompilationUnit();
SpecificCompiler compiler = new SpecificCompiler(protocol);
Expand All @@ -93,7 +103,7 @@ public void compileIdls(File[] idls, File target) throws Exception {

@Override
public void compileAvscs(AvroFileRef[] avscs, File target) throws Exception {
AvscFilesCompiler compiler = new AvscFilesCompiler();
AvscFilesCompiler compiler = new AvscFilesCompiler(this::createParser);
compiler.setStringType(stringType);
compiler.setFieldVisibility(fieldVisibility);
compiler.setUseNamespace(useNamespace);
Expand All @@ -109,7 +119,7 @@ public void compileAvscs(AvroFileRef[] avscs, File target) throws Exception {

Set<AvroFileRef> files = new HashSet<>();
for (AvroFileRef ref: avscs) {
logger.info(() -> "Compiling Avro schema: " + ref.getFile());
System.out.println("Compiling Avro schema: " + ref.getFile());
files.add(ref);
}
compiler.compileFiles(Set.of(avscs), target);
Expand All @@ -118,7 +128,7 @@ public void compileAvscs(AvroFileRef[] avscs, File target) throws Exception {
@Override
public void compileAvprs(File[] avprs, File target) throws Exception {
for (File avpr : avprs) {
logger.info(() -> "Compiling Avro protocol " + avpr);
System.out.println("Compiling Avro protocol " + avpr);
Protocol protocol = Protocol.parse(avpr);
SpecificCompiler compiler = new SpecificCompiler(protocol);
compiler.setStringType(stringType);
Expand Down
10 changes: 6 additions & 4 deletions bridge/src/main/java/com/github/sbt/avro/AvscFilesCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

public class AvscFilesCompiler {

private final Supplier<Schema.Parser> parserSupplier;
private Schema.Parser schemaParser;

private String templateDirectory;
private GenericData.StringType stringType;
private SpecificCompiler.FieldVisibility fieldVisibility;
Expand All @@ -26,9 +28,9 @@ public class AvscFilesCompiler {
private Boolean optionalGettersForNullableFieldsOnly;
private Map<AvroFileRef, Exception> compileExceptions;

public AvscFilesCompiler() {
// this.builder = builder;
this.schemaParser = new Schema.Parser(); //builder.build();
public AvscFilesCompiler(Supplier<Schema.Parser> parserSupplier) {
this.parserSupplier = parserSupplier;
this.schemaParser = parserSupplier.get();
}

public void compileFiles(Set<AvroFileRef> files, File outputDirectory) {
Expand Down Expand Up @@ -151,7 +153,7 @@ private boolean tryCompile(File src, Schema schema, File outputDirectory) {
private Schema.Parser stashParser() {
// on failure Schema.Parser changes cache state.
// We want last successful state.
Schema.Parser parser = new Schema.Parser(); // builder.build();
Schema.Parser parser = parserSupplier.get();
Set<String> predefinedTypes = parser.getTypes().keySet();
Map<String, Schema> compiledTypes = schemaParser.getTypes();
compiledTypes.keySet().removeAll(predefinedTypes);
Expand Down
1 change: 0 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ lazy val `sbt-avro-compiler-bridge`: Project = project
autoScalaLibrary := false,
libraryDependencies ++= Seq(
Dependencies.Provided.AvroCompiler,
Dependencies.Provided.SbtUtilInterface,
Dependencies.Test.Specs2Core,
)
)
Expand Down
80 changes: 32 additions & 48 deletions plugin/src/main/scala/com/github/sbt/avro/SbtAvro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ object SbtAvro extends AutoPlugin {
import Defaults._

// format: off
val avroAdditionalDependencies = settingKey[Seq[ModuleID]]("Additional dependencies to be added to library dependencies.")
val avroCompilerClass = settingKey[String]("Sbt avro compiler class. Default: com.github.sbt.avro.AvroCompilerBridge")
val avroCreateSetters = settingKey[Boolean]("Generate setters. Default: true")
val avroDependencyIncludeFilter = settingKey[DependencyFilter]("Filter for including modules containing avro dependencies.")
val avroEnableDecimalLogicalType = settingKey[Boolean]("Use java.math.BigDecimal instead of java.nio.ByteBuffer for logical type \"decimal\". Default: true.")
val avroFieldVisibility = settingKey[String]("Field visibility for the properties. Possible values: private, public. Default: public.")
val avroIncludes = settingKey[Seq[File]]("Avro schema includes.")
val avroOptionalGetters = settingKey[Boolean]("Generate getters that return Optional for nullable fields. Default: false.")
val avroSpecificRecords = settingKey[Seq[String]]("List of avro records to recompile with current avro version and settings. Classes must be part of the Avro library dependencies.")
val avroSource = settingKey[File]("Default Avro source directory.")
val avroSources = settingKey[Seq[File]]("Avro source directories.")
val avroStringType = settingKey[String]("Type for representing strings. Possible values: CharSequence, String, Utf8. Default: CharSequence.")
val avroUnpackDependencies = taskKey[Seq[File]]("Unpack avro dependencies.")
val avroUseNamespace = settingKey[Boolean]("Validate that directory layout reflects namespaces, i.e. src/main/avro/com/myorg/MyRecord.avsc. Default: false.")
Expand All @@ -49,8 +50,17 @@ object SbtAvro extends AutoPlugin {
lazy val avroArtifactTasks: Seq[TaskKey[File]] = Seq(Compile, Test).map(_ / packageAvro)

lazy val defaultSettings: Seq[Setting[_]] = Seq(
// compiler
avroCompilerClass := "com.github.sbt.avro.AvroCompilerBridge",
avroCreateSetters := true,
avroEnableDecimalLogicalType := true,
avroFieldVisibility := "public",
avroOptionalGetters := false,
avroStringType := "CharSequence",
avroUseNamespace := false,

// dependency management
avroDependencyIncludeFilter := artifactFilter(`type` = Artifact.SourceType, classifier = AvroClassifier),
avroIncludes := Seq(),
// addArtifact doesn't take publishArtifact setting in account
artifacts ++= Classpaths.artifactDefs(avroArtifactTasks).value,
packagedArtifacts ++= Classpaths.packaged(avroArtifactTasks).value,
Expand All @@ -60,24 +70,17 @@ object SbtAvro extends AutoPlugin {
// setup avro configuration. Use library management to fetch the compiler
ivyConfigurations ++= Seq(Avro),
avroVersion := "1.12.0",
libraryDependencies ++= Seq(
avroAdditionalDependencies := Seq(
"com.github.sbt" % "sbt-avro-compiler-bridge" % BuildInfo.version % Avro,
"org.apache.avro" % "avro-compiler" % avroVersion.value % Avro,
"org.apache.avro" % "avro" % avroVersion.value,
)
)

lazy val avroScopedSettings: Seq[Setting[_]] = Seq(
managedClasspath := Classpaths.managedJars(
Avro,
classpathTypes.value,
update.value
)
),
libraryDependencies ++= avroAdditionalDependencies.value
)

// settings to be applied for both Compile and Test
lazy val configScopedSettings: Seq[Setting[_]] = Seq(
avroSource := sourceDirectory.value / "avro",
avroSources := Seq(sourceDirectory.value / "avro"),
avroSpecificRecords := Seq.empty,
// dependencies
avroUnpackDependencies / includeFilter := AllPassFilter,
Expand All @@ -101,24 +104,15 @@ object SbtAvro extends AutoPlugin {
import autoImport._

def packageAvroMappings: Def.Initialize[Task[Seq[(File, String)]]] = Def.task {
(avroSource.value ** AvroFilter) pair relativeTo(avroSource.value)
avroSources.value.flatMap(src => (src ** AvroFilter).pair(relativeTo(src)))
}

override def trigger: PluginTrigger = noTrigger

override def requires: Plugins = sbt.plugins.JvmPlugin

override lazy val globalSettings: Seq[Setting[_]] = Seq(
avroStringType := "CharSequence",
avroFieldVisibility := "public",
avroEnableDecimalLogicalType := true,
avroUseNamespace := false,
avroOptionalGetters := false,
avroCreateSetters := true
)

override lazy val projectSettings: Seq[Setting[_]] = defaultSettings ++
inConfig(Avro)(avroScopedSettings) ++
inConfig(Avro)(Defaults.configSettings) ++
Seq(Compile, Test).flatMap(c => inConfig(c)(configScopedSettings))

private def unpack(deps: Seq[File],
Expand Down Expand Up @@ -174,11 +168,10 @@ object SbtAvro extends AutoPlugin {

private def sourceGeneratorTask(key: TaskKey[Seq[File]]) = Def.task {
val out = (key / streams).value
val compilerClass = avroCompilerClass.value
val version = avroVersion.value
val srcDir = avroSource.value
val externalSrcDir = (avroUnpackDependencies / target).value
val includes = avroIncludes.value
val srcDirs = Seq(externalSrcDir, srcDir) ++ includes
val srcDirs = Seq(externalSrcDir) ++ avroSources.value
val outDir = (key / target).value
val strType = avroStringType.value
val fieldVis = avroFieldVisibility.value.toUpperCase
Expand Down Expand Up @@ -213,32 +206,23 @@ object SbtAvro extends AutoPlugin {
// TODO Cache class loader
val avroClassLoader = new URLClassLoader(
"AvroClassLoader",
(Avro / managedClasspath).value.map(_.data.toURI.toURL).toArray,
(Avro / dependencyClasspath).value.map(_.data.toURI.toURL).toArray,
this.getClass.getClassLoader
)

val compiler = avroClassLoader
.loadClass("com.github.sbt.avro.AvroCompilerBridge")
.getDeclaredConstructor(
classOf[Logger],
classOf[String],
classOf[String],
classOf[Boolean],
classOf[Boolean],
classOf[Boolean],
classOf[Boolean],
)
.newInstance(
out.log,
strType,
fieldVis,
Boolean.box(useNs),
Boolean.box(enbDecimal),
Boolean.box(createSetters),
Boolean.box(optionalGetters)
)
.loadClass(compilerClass)
.getDeclaredConstructor()
.newInstance()
.asInstanceOf[AvroCompiler]

compiler.setStringType(strType)
compiler.setFieldVisibility(fieldVis)
compiler.setUseNamespace(useNs)
compiler.setEnableDecimalLogicalType(enbDecimal)
compiler.setCreateSetters(createSetters)
compiler.setOptionalGetters(optionalGetters)

try {
val recs = records.map(avroClassLoader.loadClass)
val avdls = srcDirs.flatMap(d => (d ** AvroAvdlFilter).get)
Expand Down
20 changes: 20 additions & 0 deletions plugin/src/sbt-test/sbt-avro/avscparser/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@


lazy val parser = project
.in(file("parser"))
.settings(
crossPaths := false,
autoScalaLibrary := false,
libraryDependencies ++= Seq(
"com.github.sbt" % "sbt-avro-compiler-bridge" % sys.props("plugin.version"),
"org.apache.avro" % "avro-compiler" % "1.12.0"
)
)

lazy val root = project
.in(file("."))
.enablePlugins(SbtAvro)
.dependsOn(parser % "avro")
.settings(
avroCompilerClass := "com.github.sbt.avro.CustomAvroCompiler"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.github.sbt.avro;

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.NameValidator;

import java.util.Collections;

public class CustomAvroCompiler extends AvroCompilerBridge {

@Override
protected Schema.Parser createParser() {
Schema.Parser parser = new Schema.Parser();
parser.setValidateDefaults(false);
Schema externalSchema = SchemaBuilder
.enumeration("B")
.namespace("com.github.sbt.avro.test")
.symbols("B1");
parser.addTypes(Collections.singletonList(externalSchema));
return parser;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version=1.3.0
5 changes: 5 additions & 0 deletions plugin/src/sbt-test/sbt-avro/avscparser/project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
sys.props.get("plugin.version") match {
case Some(x) => addSbtPlugin("com.github.sbt" % "sbt-avro" % x)
case _ => sys.error("""|The system property 'plugin.version' is not defined.
|Specify this property using the scriptedLaunchOpts -D.""".stripMargin)
}
Loading

0 comments on commit fb8ba6e

Please sign in to comment.