diff --git a/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializer.java b/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializer.java index d1014fe3d..ddbcafbc5 100644 --- a/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializer.java +++ b/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializer.java @@ -1,8 +1,8 @@ package com.twitter.elephantbird.hive.serde; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.Descriptors.Descriptor; -import com.twitter.elephantbird.mapreduce.io.ProtobufConverter; import com.twitter.elephantbird.util.Protobufs; import org.apache.hadoop.conf.Configuration; @@ -31,7 +31,7 @@ */ public class ProtobufDeserializer implements Deserializer { - private ProtobufConverter protobufConverter = null; + private Message.Builder msgBuilder; private ObjectInspector objectInspector; @Override @@ -42,8 +42,7 @@ public void initialize(Configuration job, Properties tbl) throws SerDeException Class protobufClass = job.getClassByName(protoClassName) .asSubclass(Message.class); - protobufConverter = ProtobufConverter.newInstance(protobufClass); - + msgBuilder = Protobufs.getMessageBuilder(protobufClass); Descriptor descriptor = Protobufs.getMessageDescriptor(protobufClass); objectInspector = new ProtobufStructObjectInspector(descriptor); } catch (Exception e) { @@ -54,7 +53,11 @@ public void initialize(Configuration job, Properties tbl) throws SerDeException @Override public Object deserialize(Writable blob) throws SerDeException { BytesWritable bytes = (BytesWritable) blob; - return protobufConverter.fromBytes(bytes.getBytes(), 0, bytes.getLength()); + try { + return msgBuilder.clear().mergeFrom(bytes.getBytes(), 0, bytes.getLength()); + } catch (InvalidProtocolBufferException e) { + throw new SerDeException(e); + } } @Override diff --git a/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufStructObjectInspector.java b/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufStructObjectInspector.java index 8f5977bab..8054ea1ae 100644 --- a/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufStructObjectInspector.java +++ b/hive/src/main/java/com/twitter/elephantbird/hive/serde/ProtobufStructObjectInspector.java @@ -9,6 +9,7 @@ import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; import com.google.protobuf.Descriptors.FieldDescriptor.Type; import com.google.protobuf.ByteString; +import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -30,6 +31,21 @@ public static class ProtobufStructField implements StructField { public ProtobufStructField(FieldDescriptor fieldDescriptor) { this.fieldDescriptor = fieldDescriptor; oi = this.createOIForField(); + comment = fieldDescriptor.getFullName(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof ProtobufStructField) { + ProtobufStructField other = (ProtobufStructField)obj; + return fieldDescriptor.equals(other.fieldDescriptor); + } + return false; + } + + @Override + public int hashCode() { + return fieldDescriptor.hashCode(); } @Override @@ -108,6 +124,21 @@ private ObjectInspector createOIForField() { } } + @Override + public boolean equals(Object obj) { + if (obj instanceof ProtobufStructObjectInspector) { + ProtobufStructObjectInspector other = (ProtobufStructObjectInspector)obj; + return this.descriptor.equals(other.descriptor) && + this.structFields.equals(other.structFields); + } + return false; + } + + @Override + public int hashCode() { + return descriptor.hashCode(); + } + @Override public Category getCategory() { return Category.STRUCT; @@ -132,15 +163,32 @@ public String getTypeName() { @Override public Object create() { - return descriptor.toProto().toBuilder().build(); + return DynamicMessage.newBuilder(descriptor); } @Override public Object setStructFieldData(Object data, StructField field, Object fieldValue) { - return ((Message) data) - .toBuilder() - .setField(descriptor.findFieldByName(field.getFieldName()), fieldValue) - .build(); + DynamicMessage.Builder builder = (DynamicMessage.Builder)data; + ProtobufStructField psf = (ProtobufStructField)field; + FieldDescriptor fd = psf.getFieldDescriptor(); + if (fd.isRepeated()) { + return builder.setField(fd, fieldValue); + } + switch (fd.getType()) { + case ENUM: + builder.setField(fd, fd.getEnumType().findValueByName((String) fieldValue)); + break; + case BYTES: + builder.setField(fd, ByteString.copyFrom((byte[])fieldValue)); + break; + case MESSAGE: + builder.setField(fd, ((Message.Builder)fieldValue).build()); + break; + default: + builder.setField(fd, fieldValue); + break; + } + return builder; } @Override @@ -153,16 +201,32 @@ public Object getStructFieldData(Object data, StructField structField) { if (data == null) { return null; } - Message m = (Message) data; + Message.Builder builder; + if (data instanceof Message.Builder) { + builder = (Message.Builder)data; + } else if (data instanceof Message) { + builder = ((Message)data).toBuilder(); + } else { + throw new RuntimeException("Type Message or Message.Builder expected: " + + data.getClass().getCanonicalName()); + } ProtobufStructField psf = (ProtobufStructField) structField; FieldDescriptor fieldDescriptor = psf.getFieldDescriptor(); - Object result = m.getField(fieldDescriptor); + Object result = builder.getField(fieldDescriptor); + + if (fieldDescriptor.isRepeated()) { + return result; + } + if (fieldDescriptor.getType() == Type.ENUM) { return ((EnumValueDescriptor)result).getName(); } if (fieldDescriptor.getType() == Type.BYTES && (result instanceof ByteString)) { return ((ByteString)result).toByteArray(); } + if (fieldDescriptor.getType() == Type.MESSAGE) { + return ((Message)result).toBuilder(); + } return result; } @@ -177,9 +241,8 @@ public List getStructFieldsDataAsList(Object data) { return null; } List result = Lists.newArrayList(); - Message m = (Message) data; - for (FieldDescriptor fd : descriptor.getFields()) { - result.add(m.getField(fd)); + for (StructField field : structFields) { + result.add(getStructFieldData(data, field)); } return result; } diff --git a/hive/src/test/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializerTest.java b/hive/src/test/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializerTest.java index 1b81d0f70..3ecba75ea 100644 --- a/hive/src/test/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializerTest.java +++ b/hive/src/test/java/com/twitter/elephantbird/hive/serde/ProtobufDeserializerTest.java @@ -1,5 +1,6 @@ package com.twitter.elephantbird.hive.serde; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -61,7 +62,7 @@ public void setUp() throws SerDeException { @Test public final void testDeserializer() throws SerDeException { BytesWritable serialized = new BytesWritable(test_ab.toByteArray()); - AddressBook ab2 = (AddressBook) deserializer.deserialize(serialized); + AddressBook ab2 = ((AddressBook.Builder) deserializer.deserialize(serialized)).build(); assertTrue(test_ab.equals(ab2)); } @@ -71,12 +72,13 @@ public final void testObjectInspector() throws SerDeException { assertEquals(oi.getCategory(), Category.STRUCT); ProtobufStructObjectInspector protobufOI = (ProtobufStructObjectInspector) oi; - List readData = protobufOI.getStructFieldsDataAsList(test_ab); + + List readData = protobufOI.getStructFieldsDataAsList(test_ab.toBuilder()); assertEquals(readData.size(), 2); @SuppressWarnings("unchecked") - ByteString byteStr = (ByteString)readData.get(1); - assertEquals(byteStr, ByteString.copyFrom(new byte[] {16,32,64,(byte) 128})); + byte[] byteStr = (byte[])readData.get(1); + assertArrayEquals(new byte[] {16,32,64,(byte) 128}, byteStr); List persons = (List) readData.get(0); assertEquals(persons.size(), 3); assertEquals(persons.get(0).getPhoneCount(), 3); @@ -101,7 +103,7 @@ public final void testObjectInspectorGetStructFieldData() throws SerDeException private void checkFields(List fields, Message message) { for (FieldDescriptor fieldDescriptor : fields) { ProtobufStructField psf = new ProtobufStructField(fieldDescriptor); - Object data = protobufOI.getStructFieldData(message, psf); + Object data = protobufOI.getStructFieldData(message.toBuilder(), psf); Object target = message.getField(fieldDescriptor); if (fieldDescriptor.getType() == Type.ENUM) { assertEquals(String.class, data.getClass());