Skip to content

Commit

Permalink
AVRO-1827: protobuf
Browse files Browse the repository at this point in the history
  • Loading branch information
clesaec committed Jul 12, 2023
1 parent 895c3db commit b20af31
Show file tree
Hide file tree
Showing 10 changed files with 1,836 additions and 515 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DatumWriter;

import com.fasterxml.jackson.databind.node.NullNode;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
Expand Down Expand Up @@ -87,15 +88,11 @@ public Object getField(Object r, String name, int pos) {

@Override
protected void setField(Object record, String name, int position, Object value, Object state) {
Builder b = (Builder) record;
FieldDescriptor f = ((FieldDescriptor[]) state)[position];
switch (f.getType()) {
case MESSAGE:
if (value == null) {
b.clearField(f);
break;
}
default:
final Builder b = (Builder) record;
final FieldDescriptor f = ((FieldDescriptor[]) state)[position];
if (value == null) {
b.clearField(f);
} else {
b.setField(f, value);
}
}
Expand All @@ -104,11 +101,9 @@ protected void setField(Object record, String name, int position, Object value,
protected Object getField(Object record, String name, int pos, Object state) {
Message m = (Message) record;
FieldDescriptor f = ((FieldDescriptor[]) state)[pos];
switch (f.getType()) {
case MESSAGE:
if (!f.isRepeated() && !m.hasField(f))
return null;
default:
if (!f.isRepeated() && !m.hasField(f) && !f.hasDefaultValue()) {
return null;
} else {
return m.getField(f);
}
}
Expand All @@ -133,6 +128,11 @@ protected boolean isRecord(Object datum) {
return datum instanceof Message;
}

/**
 * Tells whether {@code datum} should be handled as an enum value.
 *
 * @param datum the object to classify
 * @return {@code true} when {@code datum} is an {@link EnumValueDescriptor}, {@code false} otherwise
 *         (including for {@code null})
 */
@Override
protected boolean isEnum(final Object datum) {
  final boolean isProtobufEnumValue = datum instanceof EnumValueDescriptor;
  return isProtobufEnumValue;
}

@Override
public Object newRecord(Object old, Schema schema) {
try {
Expand All @@ -148,6 +148,11 @@ public Object newRecord(Object old, Schema schema) {
}
}

/**
 * Looks up the schema for an enum value by delegating to {@code getSchema} with the value's enum type.
 *
 * @param enumObj the enum value; cast to {@link EnumValueDescriptor} without a prior type check, so any other
 *                type fails with a {@link ClassCastException}
 * @return the schema produced by {@code getSchema} for the value's {@code getType()}
 */
@Override
protected Schema getEnumSchema(final Object enumObj) {
  final EnumValueDescriptor enumValue = (EnumValueDescriptor) enumObj;
  return getSchema(enumValue.getType());
}

@Override
protected boolean isArray(Object datum) {
return datum instanceof List;
Expand Down Expand Up @@ -274,46 +279,57 @@ public Schema getSchema(FieldDescriptor f) {
}

private Schema getNonRepeatedSchema(FieldDescriptor f) {
Schema result;
switch (f.getType()) {
case BOOL:
return Schema.create(Schema.Type.BOOLEAN);
return getSchema(Schema.create(Schema.Type.BOOLEAN), f);
case FLOAT:
return Schema.create(Schema.Type.FLOAT);
return getSchema(Schema.create(Schema.Type.FLOAT), f);
case DOUBLE:
return Schema.create(Schema.Type.DOUBLE);
return getSchema(Schema.create(Schema.Type.DOUBLE), f);
case STRING:
Schema s = Schema.create(Schema.Type.STRING);
GenericData.setStringType(s, GenericData.StringType.String);
return s;
return getSchema(s, f);
case BYTES:
return Schema.create(Schema.Type.BYTES);
return getSchema(Schema.create(Schema.Type.BYTES), f);
case INT32:
case UINT32:
case SINT32:
case FIXED32:
case SFIXED32:
return Schema.create(Schema.Type.INT);
return getSchema(Schema.create(Schema.Type.INT), f);
case INT64:
case UINT64:
case SINT64:
case FIXED64:
case SFIXED64:
return Schema.create(Schema.Type.LONG);
return getSchema(Schema.create(Schema.Type.LONG), f);
case ENUM:
return getSchema(f.getEnumType());
return getSchema(getSchema(f.getEnumType()), f);
case MESSAGE:
result = getSchema(f.getMessageType());
if (f.isOptional())
// wrap optional record fields in a union with null
result = Schema.createUnion(Arrays.asList(NULL, result));
return result;
return getSchema(getSchema(f.getMessageType()), f);
case GROUP: // groups are deprecated
default:
throw new RuntimeException("Unexpected type: " + f.getType());
}
}

/**
 * Makes the schema of an optional field nullable.
 *
 * A field that is not optional, or that declares an explicit default value, keeps {@code schema} unchanged.
 * Every other field gets a two-branch union of {@code schema} and {@code NULL}.
 *
 * @param schema the schema already derived for the field's type
 * @param f      the field descriptor the schema belongs to
 * @return {@code schema} itself, or a union of {@code schema} and {@code NULL}
 */
private Schema getSchema(Schema schema, FieldDescriptor f) {
  if (!f.isOptional() || f.hasDefaultValue()) {
    return schema;
  }
  // NULL goes first only when the computed default is the JSON null singleton; otherwise the value schema
  // leads. NOTE(review): presumably so the field default matches the union's first branch -- confirm
  // against the Avro specification.
  final boolean defaultIsNull = this.getDefault(f) == NullNode.getInstance();
  return Schema.createUnion(defaultIsNull ? Arrays.asList(NULL, schema) : Arrays.asList(schema, NULL));
}

public Schema getSchema(EnumDescriptor d) {
List<String> symbols = new ArrayList<>(d.getValues().size());
for (EnumValueDescriptor e : d.getValues()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;

import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
Expand All @@ -30,14 +36,30 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

import com.google.protobuf.ByteString;

import org.apache.avro.protobuf.noopt.Test.Foo;
import org.apache.avro.protobuf.noopt.Test.A;
import org.apache.avro.protobuf.noopt.Test.M;
import org.apache.avro.protobuf.noopt.Test.M.N;

public class TestProtobuf {

/**
 * Writes a protobuf object with a {@link ProtobufDatumWriter} using Avro's JSON encoding, then reads the
 * bytes back with a {@link GenericDatumReader} using the same schema as writer and reader schema.
 *
 * @param objToConvert the protobuf object to write
 * @param clazz        the class used both to build the writer and to look up the schema via
 *                     {@code ProtobufData.get().getSchema(clazz)}
 * @param <T>          the type of {@code objToConvert}
 * @return the generic record decoded from the JSON-encoded bytes
 * @throws Exception if writing or reading fails
 */
protected <T> GenericRecord convertProtoToAvro(T objToConvert, Class clazz) throws Exception {
  // Encode: protobuf object -> Avro JSON bytes.
  ByteArrayOutputStream bao = new ByteArrayOutputStream();
  ProtobufDatumWriter<T> w = new ProtobufDatumWriter<T>(clazz);
  Schema schema = ProtobufData.get().getSchema(clazz);
  Encoder e = EncoderFactory.get().jsonEncoder(schema, bao);
  w.write(objToConvert, e);
  e.flush(); // make sure everything buffered by the encoder reaches the byte stream before reading it

  // Decode: Avro JSON bytes -> generic record. The reader is parameterized so no raw type or cast is needed.
  GenericDatumReader<GenericRecord> gdr = new GenericDatumReader<>(schema, schema);
  Decoder d = DecoderFactory.get().jsonDecoder(schema, new ByteArrayInputStream(bao.toByteArray()));

  return gdr.read(null, d);
}

@Test
void message() throws Exception {

Expand Down Expand Up @@ -146,4 +168,59 @@ void getNonRepeatedSchemaWithLogicalType() throws Exception {
Schema s2 = instance2.getSchema(com.google.protobuf.Timestamp.class);
assertEquals(conversion.getRecommendedSchema(), s2);
}

/**
 * A message whose nested enum field is set to {@code A} must come back from the conversion with an enum
 * symbol that compares equal to {@code A}.
 */
@Test
void nestedEnumWithValue() throws Exception {
  M message = M.newBuilder().setEnumN(M.N.A).build();

  GenericRecord converted = convertProtoToAvro(message, M.class);

  Schema expectedSchema = Schema.createEnum("N", null, null, Arrays.asList("A"));
  GenericEnumSymbol expected = new GenericData.EnumSymbol(expectedSchema, "A");
  GenericEnumSymbol actual = (GenericEnumSymbol) converted.get("enumN");
  assertEquals(0, actual.compareTo(expected));
}

/**
 * A message built without setting its nested enum field must expose {@code null} for {@code enumN} after
 * the conversion.
 */
@Test
void nestedEnumWithNull() throws Exception {
  M emptyMessage = M.newBuilder().build();

  Object enumValue = convertProtoToAvro(emptyMessage, M.class).get("enumN");

  assertNull(enumValue);
}

/**
 * Converts a {@code Foo} with only {@code int32} and {@code int64} set and checks every field of the
 * result: the two set fields keep their values, the unset scalar fields are {@code null}, the enum field
 * compares equal to {@code Z}, and the repeated fields are empty arrays.
 *
 * NOTE(review): the enum expectation presumably reflects a default of {@code Z} declared in the .proto
 * file -- confirm against the proto definition.
 */
@Test
void handlingOptionalValuesCorrectly() throws Exception {
  Foo foo = Foo.newBuilder().setInt32(10).setInt64(2).build();

  GenericRecord converted = convertProtoToAvro(foo, Foo.class);

  // Fields that were explicitly set.
  assertEquals(10, converted.get("int32"));
  assertEquals(2L, converted.get("int64"));

  // Scalar fields that were never set.
  for (String unsetField : Arrays.asList("uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32",
      "sfixed64", "float", "double", "bool", "string", "bytes")) {
    assertNull(converted.get(unsetField));
  }

  // The enum field was not set either, yet it is expected to carry the symbol Z.
  Schema enumSchema = Schema.createEnum("A", null, null, Arrays.asList("X", "Y", "Z"));
  GenericEnumSymbol enumZ = new GenericData.EnumSymbol(enumSchema, "Z");
  assertEquals(0, ((GenericEnumSymbol) converted.get("enum")).compareTo(enumZ));

  // Repeated fields.
  for (String repeatedField : Arrays.asList("intArray", "fooArray", "syms")) {
    assertEquals(0, ((GenericData.Array) converted.get(repeatedField)).size());
  }
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b20af31

Please sign in to comment.