This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 8f858f398de3c47fe5045fbe0f1497022bfb1c15
Author: Jark Wu <j...@apache.org>
AuthorDate: Thu Jun 4 20:01:47 2020 +0800

    [FLINK-18073][avro] Fix AvroRowDataSerializationSchema is not serializable
    
    This closes #12471
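    
    Root cause: the SerializationRuntimeConverter lambdas captured an
    org.apache.avro.Schema, and Avro's Schema does not implement
    java.io.Serializable, so serializing the converter chain failed. The
    patch stops capturing the Schema: converters now receive it through
    convert(Schema, Object), and the schema field is transient and rebuilt
    in open(). A minimal standalone sketch of the failure mode
    (hypothetical names, not code from this patch):
    
        import java.io.ByteArrayOutputStream;
        import java.io.NotSerializableException;
        import java.io.ObjectOutputStream;
        import java.io.Serializable;
    
        public class CaptureDemo {
            // Mirrors SerializationRuntimeConverter before the fix.
            interface Converter extends Serializable {
                Object convert(Object object);
            }
    
            // Stand-in for org.apache.avro.Schema, which is not Serializable.
            static class FakeSchema { }
    
            public static void main(String[] args) throws Exception {
                FakeSchema schema = new FakeSchema();
                // The lambda captures `schema`; serializing the lambda then
                // tries to serialize the captured FakeSchema and fails.
                Converter broken = object -> schema;
                try (ObjectOutputStream out =
                        new ObjectOutputStream(new ByteArrayOutputStream())) {
                    out.writeObject(broken);
                } catch (NotSerializableException e) {
                    System.out.println("capture fails: " + e);
                }
            }
        }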
---
 .../formats/avro/AvroFileSystemFormatFactory.java  |  3 +-
 .../avro/AvroRowDataSerializationSchema.java       | 92 +++++++++++++++++-----
 .../avro/typeutils/AvroSchemaConverter.java        | 23 +++---
 .../avro/typeutils/AvroSchemaConverterTest.java    | 42 ++++++++++
 4 files changed, 127 insertions(+), 33 deletions(-)

diff --git a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java
index a033739..c60c42c 100644
--- a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java
+++ b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java
@@ -243,11 +243,12 @@ public class AvroFileSystemFormatFactory implements FileSystemFormatFactory {
                        BulkWriter<GenericRecord> writer = factory.create(out);
                        AvroRowDataSerializationSchema.SerializationRuntimeConverter converter =
                                        AvroRowDataSerializationSchema.createRowConverter(rowType);
+                       Schema schema = AvroSchemaConverter.convertToSchema(rowType);
                        return new BulkWriter<RowData>() {
 
                                @Override
                                public void addElement(RowData element) throws IOException {
-                                       GenericRecord record = (GenericRecord) converter.convert(element);
+                                       GenericRecord record = (GenericRecord) converter.convert(schema, element);
                                        writer.addElement(record);
                                }
 
diff --git a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java
index 00b7ac5..5b1fbbe 100644
--- a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java
+++ b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java
@@ -75,6 +75,11 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowDa
        private final SerializationRuntimeConverter runtimeConverter;
 
        /**
+        * Avro serialization schema.
+        */
+       private transient Schema schema;
+
+       /**
         * Writer to serialize an Avro record into Avro bytes.
         */
        private transient DatumWriter<IndexedRecord> datumWriter;
@@ -99,7 +104,7 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowDa
 
        @Override
        public void open(InitializationContext context) throws Exception {
-               final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
+               this.schema = AvroSchemaConverter.convertToSchema(rowType);
                datumWriter = new SpecificDatumWriter<>(schema);
                arrayOutputStream = new ByteArrayOutputStream();
                encoder = EncoderFactory.get().binaryEncoder(arrayOutputStream, null);
@@ -109,7 +114,7 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowDa
        public byte[] serialize(RowData row) {
                try {
                        // convert to record
-                       final GenericRecord record = (GenericRecord) runtimeConverter.convert(row);
+                       final GenericRecord record = (GenericRecord) runtimeConverter.convert(schema, row);
                        arrayOutputStream.reset();
                        datumWriter.write(record, encoder);
                        encoder.flush();
@@ -145,33 +150,43 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowDa
         * to corresponding Avro data structures.
         */
        interface SerializationRuntimeConverter extends Serializable {
-               Object convert(Object object);
+               Object convert(Schema schema, Object object);
        }
 
        static SerializationRuntimeConverter createRowConverter(RowType rowType) {
                final SerializationRuntimeConverter[] fieldConverters = rowType.getChildren().stream()
                        .map(AvroRowDataSerializationSchema::createConverter)
                        .toArray(SerializationRuntimeConverter[]::new);
-               final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
                final LogicalType[] fieldTypes = rowType.getFields().stream()
                        .map(RowType.RowField::getType)
                        .toArray(LogicalType[]::new);
+               final RowData.FieldGetter[] fieldGetters = new RowData.FieldGetter[fieldTypes.length];
+               for (int i = 0; i < fieldTypes.length; i++) {
+                       fieldGetters[i] = RowData.createFieldGetter(fieldTypes[i], i);
+               }
                final int length = rowType.getFieldCount();
 
-               return object -> {
+               return (schema, object) -> {
                        final RowData row = (RowData) object;
+                       final List<Schema.Field> fields = schema.getFields();
                        final GenericRecord record = new GenericData.Record(schema);
                        for (int i = 0; i < length; ++i) {
-                               record.put(i, fieldConverters[i].convert(RowData.get(row, i, fieldTypes[i])));
+                               final Schema.Field schemaField = fields.get(i);
+                               Object avroObject = fieldConverters[i].convert(
+                                       schemaField.schema(),
+                                       fieldGetters[i].getFieldOrNull(row));
+                               record.put(i, avroObject);
                        }
                        return record;
                };
        }
 
        private static SerializationRuntimeConverter createConverter(LogicalType type) {
+               final SerializationRuntimeConverter converter;
                switch (type.getTypeRoot()) {
                        case NULL:
-                               return object -> null;
+                               converter = (schema, object) -> null;
+                               break;
                        case BOOLEAN: // boolean
                        case INTEGER: // int
                        case INTERVAL_YEAR_MONTH: // long
@@ -181,39 +196,74 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowDa
                        case DOUBLE: // double
                        case TIME_WITHOUT_TIME_ZONE: // int
                        case DATE: // int
-                               return avroObject -> avroObject;
+                               converter = (schema, object) -> object;
+                               break;
                        case CHAR:
                        case VARCHAR:
-                               return object -> new Utf8(object.toString());
+                               converter = (schema, object) -> new Utf8(object.toString());
+                               break;
                        case BINARY:
                        case VARBINARY:
-                               return object -> ByteBuffer.wrap((byte[]) object);
+                               converter = (schema, object) -> ByteBuffer.wrap((byte[]) object);
+                               break;
                        case TIMESTAMP_WITHOUT_TIME_ZONE:
-                               return object -> ((TimestampData) object).toTimestamp().getTime();
+                               converter = (schema, object) -> ((TimestampData) object).toTimestamp().getTime();
+                               break;
                        case DECIMAL:
-                               return object -> ByteBuffer.wrap(((DecimalData) object).toUnscaledBytes());
+                               converter = (schema, object) -> ByteBuffer.wrap(((DecimalData) object).toUnscaledBytes());
+                               break;
                        case ARRAY:
-                               return createArrayConverter((ArrayType) type);
+                               converter = createArrayConverter((ArrayType) type);
+                               break;
                        case ROW:
-                               return createRowConverter((RowType) type);
+                               converter = createRowConverter((RowType) type);
+                               break;
                        case MAP:
                        case MULTISET:
-                               return createMapConverter(type);
+                               converter = createMapConverter(type);
+                               break;
                        case RAW:
                        default:
                                throw new UnsupportedOperationException("Unsupported type: " + type);
                }
+
+               // wrap into nullable converter
+               return (schema, object) -> {
+                       if (object == null) {
+                               return null;
+                       }
+
+                       // get actual schema if it is a nullable schema
+                       Schema actualSchema;
+                       if (schema.getType() == Schema.Type.UNION) {
+                               List<Schema> types = schema.getTypes();
+                               int size = types.size();
+                               if (size == 2 && types.get(1).getType() == Schema.Type.NULL) {
+                                       actualSchema = types.get(0);
+                               } else if (size == 2 && types.get(0).getType() == Schema.Type.NULL) {
+                                       actualSchema = types.get(1);
+                               } else {
+                                       throw new IllegalArgumentException(
+                                               "The Avro schema is not a nullable type: " + schema.toString());
+                               }
+                       } else {
+                               actualSchema = schema;
+                       }
+                       return converter.convert(actualSchema, object);
+               };
        }
 
        private static SerializationRuntimeConverter createArrayConverter(ArrayType arrayType) {
+               LogicalType elementType = arrayType.getElementType();
+               final ArrayData.ElementGetter elementGetter = ArrayData.createElementGetter(elementType);
                final SerializationRuntimeConverter elementConverter = createConverter(arrayType.getElementType());
-               final LogicalType elementType = arrayType.getElementType();
 
-               return object -> {
+               return (schema, object) -> {
+                       final Schema elementSchema = schema.getElementType();
                        ArrayData arrayData = (ArrayData) object;
                        List<Object> list = new ArrayList<>();
                        for (int i = 0; i < arrayData.size(); ++i) {
-                               list.add(elementConverter.convert(ArrayData.get(arrayData, i, elementType)));
+                               list.add(elementConverter.convert(elementSchema, elementGetter.getElementOrNull(arrayData, i)));
                        }
                        return list;
                };
@@ -221,16 +271,18 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowDa
 
        private static SerializationRuntimeConverter createMapConverter(LogicalType type) {
                LogicalType valueType = extractValueTypeToAvroMap(type);
+               final ArrayData.ElementGetter valueGetter = ArrayData.createElementGetter(valueType);
                final SerializationRuntimeConverter valueConverter = createConverter(valueType);
 
-               return object -> {
+               return (schema, object) -> {
+                       final Schema valueSchema = schema.getValueType();
                        final MapData mapData = (MapData) object;
                        final ArrayData keyArray = mapData.keyArray();
                        final ArrayData valueArray = mapData.valueArray();
                        final Map<Object, Object> map = new HashMap<>(mapData.size());
                        for (int i = 0; i < mapData.size(); ++i) {
                                final String key = keyArray.getString(i).toString();
-                               final Object value = valueConverter.convert(ArrayData.get(valueArray, i, valueType));
+                               final Object value = valueConverter.convert(valueSchema, valueGetter.getElementOrNull(valueArray, i));
                                map.put(key, value);
                        }
                        return map;
diff --git a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java
index 774fadf..37745e5 100644
--- a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java
+++ b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.MultisetType;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 import org.apache.flink.types.Row;
@@ -179,6 +180,7 @@ public class AvroSchemaConverter {
        }
 
        public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
+               int precision;
                switch (logicalType.getTypeRoot()) {
                        case NULL:
                                return SchemaBuilder.builder().nullType();
@@ -201,20 +203,25 @@ public class AvroSchemaConverter {
                        case TIMESTAMP_WITHOUT_TIME_ZONE:
                                // use long to represent Timestamp
                                final TimestampType timestampType = (TimestampType) logicalType;
-                               int precision = timestampType.getPrecision();
+                               precision = timestampType.getPrecision();
                                org.apache.avro.LogicalType avroLogicalType;
                                if (precision <= 3) {
                                        avroLogicalType = LogicalTypes.timestampMillis();
                                } else {
-                                       throw new IllegalArgumentException("Avro Timestamp does not support Timestamp with precision: " +
-                                               precision +
-                                               ", it only supports precision of 3 or 9.");
+                                       throw new IllegalArgumentException("Avro does not support TIMESTAMP type " +
+                                               "with precision: " + precision + ", it only supports precision less than or equal to 3.");
                                }
                                return avroLogicalType.addToSchema(SchemaBuilder.builder().longType());
                        case DATE:
                                // use int to represent Date
                                return LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType());
                        case TIME_WITHOUT_TIME_ZONE:
+                               precision = ((TimeType) logicalType).getPrecision();
+                               if (precision > 3) {
+                                       throw new IllegalArgumentException(
+                                               "Avro does not support TIME type with precision: " + precision +
+                                               ", it only supports precision less than or equal to 3.");
+                               }
                                // use int to represent Time; we only support millisecond precision in deserialization
                                return LogicalTypes.timeMillis().addToSchema(SchemaBuilder.builder().intType());
                        case DECIMAL:
@@ -254,14 +261,6 @@ public class AvroSchemaConverter {
                                        .array()
                                        .items(convertToSchema(arrayType.getElementType(), rowTypeCounter));
                        case RAW:
-                               // if the union type has more than 2 types, it will be recognized a generic type
-                               // see AvroRowDeserializationSchema#convertAvroType and AvroRowSerializationSchema#convertFlinkType
-                               return SchemaBuilder.builder().unionOf()
-                                       .nullType().and()
-                                       .booleanType().and()
-                                       .longType().and()
-                                       .doubleType()
-                                       .endUnion();
                        case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
                        default:
                                throw new UnsupportedOperationException("Unsupported to derive Schema for type: " + logicalType);
diff --git a/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java b/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java
index be0ddc4..fa499b7 100644
--- a/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java
+++ b/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java
@@ -22,9 +22,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.formats.avro.generated.User;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableSchema;
+import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.types.Row;
 
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -34,6 +39,9 @@ import static org.junit.Assert.assertTrue;
  */
 public class AvroSchemaConverterTest {
 
+       @Rule
+       public ExpectedException thrown = ExpectedException.none();
+
        @Test
        public void testAvroClassConversion() {
                validateUserSchema(AvroSchemaConverter.convertToTypeInfo(User.class));
@@ -45,6 +53,40 @@
                validateUserSchema(AvroSchemaConverter.convertToTypeInfo(schema));
        }
 
+       @Test
+       public void testInvalidRawTypeAvroSchemaConversion() {
+               RowType rowType = (RowType) TableSchema.builder()
+                       .field("a", DataTypes.STRING())
+                       .field("b", DataTypes.RAW(Types.GENERIC(AvroSchemaConverterTest.class)))
+                       .build().toRowDataType().getLogicalType();
+               thrown.expect(UnsupportedOperationException.class);
+               thrown.expectMessage("Unsupported to derive Schema for type: RAW");
+               AvroSchemaConverter.convertToSchema(rowType);
+       }
+
+       @Test
+       public void testInvalidTimestampTypeAvroSchemaConversion() {
+               RowType rowType = (RowType) TableSchema.builder()
+                       .field("a", DataTypes.STRING())
+                       .field("b", DataTypes.TIMESTAMP(9))
+                       .build().toRowDataType().getLogicalType();
+               thrown.expect(IllegalArgumentException.class);
+               thrown.expectMessage("Avro does not support TIMESTAMP type with precision: 9, " +
+                       "it only supports precision less than or equal to 3.");
+               AvroSchemaConverter.convertToSchema(rowType);
+       }
+
+       @Test
+       public void testInvalidTimeTypeAvroSchemaConversion() {
+               RowType rowType = (RowType) TableSchema.builder()
+                       .field("a", DataTypes.STRING())
+                       .field("b", DataTypes.TIME(6))
+                       .build().toRowDataType().getLogicalType();
+               thrown.expect(IllegalArgumentException.class);
+               thrown.expectMessage("Avro does not support TIME type with precision: 6, it only supports precision less than or equal to 3.");
+               AvroSchemaConverter.convertToSchema(rowType);
+       }
+
        private void validateUserSchema(TypeInformation<?> actual) {
                final TypeInformation<Row> address = Types.ROW_NAMED(
                        new String[]{

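A quick serializability check, as a sketch: the one-argument
AvroRowDataSerializationSchema(RowType) constructor from this patch's version
and Flink's InstantiationUtil are assumed, and the class name is illustrative.

    import org.apache.flink.formats.avro.AvroRowDataSerializationSchema;
    import org.apache.flink.table.api.DataTypes;
    import org.apache.flink.table.types.logical.RowType;
    import org.apache.flink.util.InstantiationUtil;

    public class SerializableCheck {
        public static void main(String[] args) throws Exception {
            RowType rowType = (RowType) DataTypes.ROW(
                    DataTypes.FIELD("a", DataTypes.STRING())).getLogicalType();
            AvroRowDataSerializationSchema serSchema =
                    new AvroRowDataSerializationSchema(rowType);
            // Before this fix the converter lambdas captured the Avro Schema,
            // so this round trip threw NotSerializableException.
            byte[] bytes = InstantiationUtil.serializeObject(serSchema);
            AvroRowDataSerializationSchema copy = InstantiationUtil.deserializeObject(
                    bytes, SerializableCheck.class.getClassLoader());
            System.out.println("round trip ok: " + (copy != null));
        }
    }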