diff --git a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java index 343d506efb5..6c2e19f2767 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java +++ b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java @@ -19,6 +19,7 @@ package org.apache.seatunnel.format.avro; import org.apache.seatunnel.api.table.type.ArrayType; +import org.apache.seatunnel.api.table.type.MapType; import org.apache.seatunnel.api.table.type.SeaTunnelDataType; import org.apache.seatunnel.api.table.type.SeaTunnelRow; import org.apache.seatunnel.api.table.type.SeaTunnelRowType; @@ -37,6 +38,8 @@ import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; public class RowToAvroConverter implements Serializable { @@ -87,24 +90,27 @@ private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType } switch (seaTunnelDataType.getSqlType()) { case STRING: - case SMALLINT: case INT: case BIGINT: case FLOAT: case DOUBLE: case BOOLEAN: - case MAP: case DECIMAL: case DATE: case TIMESTAMP: return data; case TINYINT: + case SMALLINT: Class typeClass = seaTunnelDataType.getTypeClass(); if (typeClass == Byte.class) { if (data instanceof Byte) { Byte aByte = (Byte) data; return Byte.toUnsignedInt(aByte); } + } else if (typeClass == Short.class) { + if (data instanceof Short) { + return ((Short) data).intValue(); + } } return data; case BYTES: @@ -118,6 +124,18 @@ private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType records.add(resolveObject(Array.get(data, i), basicType)); } return records; + case MAP: + MapType mapType = (MapType) seaTunnelDataType; + SeaTunnelDataType keyType = mapType.getKeyType(); + SeaTunnelDataType valueType = mapType.getValueType(); + Map mapData = new HashMap<>(); + for (Map.Entry entry : ((Map) data).entrySet()) { + mapData.put( + resolveObject(entry.getKey(), keyType), + resolveObject(entry.getValue(), valueType)); + } + return mapData; + case ROW: SeaTunnelRow seaTunnelRow = (SeaTunnelRow) data; SeaTunnelDataType[] fieldTypes = diff --git a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java index fb45a0b5377..60f7281331a 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java +++ b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java @@ -66,13 +66,16 @@ private SeaTunnelRow buildSeaTunnelRow() { subSeaTunnelRow.setField(12, bigDecimal); subSeaTunnelRow.setField(13, localDateTime); + Map mapData = new HashMap<>(); + mapData.put("k1", Short.valueOf("1")); + mapData.put("k2", Short.valueOf("2")); SeaTunnelRow seaTunnelRow = new SeaTunnelRow(15); - seaTunnelRow.setField(0, map); + seaTunnelRow.setField(0, mapData); seaTunnelRow.setField(1, strArray); seaTunnelRow.setField(2, "strVal"); seaTunnelRow.setField(3, true); - seaTunnelRow.setField(4, 1); - seaTunnelRow.setField(5, 2); + seaTunnelRow.setField(4, new Byte("1")); + seaTunnelRow.setField(5, Short.valueOf("2")); seaTunnelRow.setField(6, 3); seaTunnelRow.setField(7, Long.MAX_VALUE - 1); seaTunnelRow.setField(8, 33.333F); @@ -138,12 +141,12 @@ private SeaTunnelRowType buildSeaTunnelRowType() { "c_row" }; SeaTunnelDataType[] fieldTypes = { - new MapType<>(BasicType.STRING_TYPE, BasicType.STRING_TYPE), + new MapType<>(BasicType.STRING_TYPE, BasicType.SHORT_TYPE), ArrayType.STRING_ARRAY_TYPE, BasicType.STRING_TYPE, BasicType.BOOLEAN_TYPE, - BasicType.INT_TYPE, - BasicType.INT_TYPE, + BasicType.BYTE_TYPE, + BasicType.SHORT_TYPE, BasicType.INT_TYPE, BasicType.LONG_TYPE, BasicType.FLOAT_TYPE, diff --git a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java index 52ba7d76e68..18ca6d06e08 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java +++ b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java @@ -48,8 +48,8 @@ class AvroSerializationSchemaTest { private SeaTunnelRow buildSeaTunnelRow() { SeaTunnelRow subSeaTunnelRow = new SeaTunnelRow(14); Map map = new HashMap<>(); - map.put("k1", "v1"); - map.put("k2", "v2"); + map.put("k1", "1"); + map.put("k2", "2"); String[] strArray = new String[] {"l1", "l2"}; byte byteVal = 100; subSeaTunnelRow.setField(0, map); @@ -67,13 +67,16 @@ private SeaTunnelRow buildSeaTunnelRow() { subSeaTunnelRow.setField(12, bigDecimal); subSeaTunnelRow.setField(13, localDateTime); + Map mapData = new HashMap<>(); + mapData.put("k1", Short.valueOf("1")); + mapData.put("k2", Short.valueOf("2")); SeaTunnelRow seaTunnelRow = new SeaTunnelRow(15); - seaTunnelRow.setField(0, map); + seaTunnelRow.setField(0, mapData); seaTunnelRow.setField(1, strArray); seaTunnelRow.setField(2, "strVal"); seaTunnelRow.setField(3, true); - seaTunnelRow.setField(4, 1); - seaTunnelRow.setField(5, 2); + seaTunnelRow.setField(4, new Byte("1")); + seaTunnelRow.setField(5, Short.valueOf("2")); seaTunnelRow.setField(6, 3); seaTunnelRow.setField(7, Long.MAX_VALUE - 1); seaTunnelRow.setField(8, 33.333F); @@ -138,12 +141,12 @@ private SeaTunnelRowType buildSeaTunnelRowType() { "c_row" }; SeaTunnelDataType[] fieldTypes = { - new MapType<>(BasicType.STRING_TYPE, BasicType.STRING_TYPE), + new MapType<>(BasicType.STRING_TYPE, BasicType.SHORT_TYPE), ArrayType.STRING_ARRAY_TYPE, BasicType.STRING_TYPE, BasicType.BOOLEAN_TYPE, - BasicType.INT_TYPE, - BasicType.INT_TYPE, + BasicType.BYTE_TYPE, + BasicType.SHORT_TYPE, BasicType.INT_TYPE, BasicType.LONG_TYPE, BasicType.FLOAT_TYPE,