Repository: spark
Updated Branches:
  refs/heads/master 8812746d4 -> 2eaf05878


[SPARK-25718][SQL] Detect recursive reference in Avro schema and throw exception

## What changes were proposed in this pull request?

Avro schemas allow recursive references, e.g. the schema for a linked list in 
https://avro.apache.org/docs/1.8.2/spec.html#schema_record
```
{
  "type": "record",
  "name": "LongList",
  "aliases": ["LinkedLongs"],                      // old name for this
  "fields" : [
    {"name": "value", "type": "long"},             // each element has a long
    {"name": "next", "type": ["null", "LongList"]} // optional next element
  ]
}
```

Currently, Spark SQL cannot convert such a schema to a `StructType`. 
Calling `SchemaConverters.toSqlType(avroSchema)` recurses without end and 
fails with a stack overflow error.

We should detect the recursive reference and throw an exception
(`IncompatibleSchemaException`) for it.

## How was this patch tested?

New unit test case.

Closes #22709 from gengliangwang/avroRecursiveRef.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2eaf0587
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2eaf0587
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2eaf0587

Branch: refs/heads/master
Commit: 2eaf0587883ac3c65e77d01ffbb39f64c6152f87
Parents: 8812746
Author: Gengliang Wang <[email protected]>
Authored: Sat Oct 13 14:49:38 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Sat Oct 13 14:49:38 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/avro/SchemaConverters.scala       | 26 +++++---
 .../org/apache/spark/sql/avro/AvroSuite.scala   | 65 ++++++++++++++++++++
 2 files changed, 84 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2eaf0587/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
----------------------------------------------------------------------
diff --git 
a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala 
b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index bd15765..64127af 100644
--- 
a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ 
b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -43,6 +43,10 @@ object SchemaConverters {
    * This function takes an avro schema and returns a sql schema.
    */
   def toSqlType(avroSchema: Schema): SchemaType = {
+    toSqlTypeHelper(avroSchema, Set.empty)
+  }
+
+  def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): 
SchemaType = {
     avroSchema.getType match {
       case INT => avroSchema.getLogicalType match {
         case _: Date => SchemaType(DateType, nullable = false)
@@ -67,21 +71,28 @@ object SchemaConverters {
       case ENUM => SchemaType(StringType, nullable = false)
 
       case RECORD =>
+        if (existingRecordNames.contains(avroSchema.getFullName)) {
+          throw new IncompatibleSchemaException(s"""
+            |Found recursive reference in Avro schema, which can not be 
processed by Spark:
+            |${avroSchema.toString(true)}
+          """.stripMargin)
+        }
+        val newRecordNames = existingRecordNames + avroSchema.getFullName
         val fields = avroSchema.getFields.asScala.map { f =>
-          val schemaType = toSqlType(f.schema())
+          val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
           StructField(f.name, schemaType.dataType, schemaType.nullable)
         }
 
         SchemaType(StructType(fields), nullable = false)
 
       case ARRAY =>
-        val schemaType = toSqlType(avroSchema.getElementType)
+        val schemaType = toSqlTypeHelper(avroSchema.getElementType, 
existingRecordNames)
         SchemaType(
           ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
           nullable = false)
 
       case MAP =>
-        val schemaType = toSqlType(avroSchema.getValueType)
+        val schemaType = toSqlTypeHelper(avroSchema.getValueType, 
existingRecordNames)
         SchemaType(
           MapType(StringType, schemaType.dataType, valueContainsNull = 
schemaType.nullable),
           nullable = false)
@@ -91,13 +102,14 @@ object SchemaConverters {
           // In case of a union with null, eliminate it and make a recursive 
call
           val remainingUnionTypes = 
avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
           if (remainingUnionTypes.size == 1) {
-            toSqlType(remainingUnionTypes.head).copy(nullable = true)
+            toSqlTypeHelper(remainingUnionTypes.head, 
existingRecordNames).copy(nullable = true)
           } else {
-            
toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
+            toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), 
existingRecordNames)
+              .copy(nullable = true)
           }
         } else avroSchema.getTypes.asScala.map(_.getType) match {
           case Seq(t1) =>
-            toSqlType(avroSchema.getTypes.get(0))
+            toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
           case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
             SchemaType(LongType, nullable = false)
           case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
@@ -107,7 +119,7 @@ object SchemaConverters {
             // This is consistent with the behavior when converting between 
Avro and Parquet.
             val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
               case (s, i) =>
-                val schemaType = toSqlType(s)
+                val schemaType = toSqlTypeHelper(s, existingRecordNames)
                 // All fields are nullable because only one of them is set at 
a time
                 StructField(s"member$i", schemaType.dataType, nullable = true)
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/2eaf0587/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
----------------------------------------------------------------------
diff --git 
a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala 
b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 1e08f7b..4fea2cb 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -1309,4 +1309,69 @@ class AvroSuite extends QueryTest with SharedSQLContext 
with SQLTestUtils {
       checkCodec(df, path, "xz")
     }
   }
+
+  private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = {
+    val message = intercept[IncompatibleSchemaException] {
+      SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema))
+    }.getMessage
+
+    assert(message.contains("Found recursive reference in Avro schema"))
+  }
+
+  test("Detect recursive loop") {
+    checkSchemaWithRecursiveLoop("""
+      |{
+      |  "type": "record",
+      |  "name": "LongList",
+      |  "fields" : [
+      |    {"name": "value", "type": "long"},             // each element has 
a long
+      |    {"name": "next", "type": ["null", "LongList"]} // optional next 
element
+      |  ]
+      |}
+    """.stripMargin)
+
+    checkSchemaWithRecursiveLoop("""
+      |{
+      |  "type": "record",
+      |  "name": "LongList",
+      |  "fields": [
+      |    {
+      |      "name": "value",
+      |      "type": {
+      |        "type": "record",
+      |        "name": "foo",
+      |        "fields": [
+      |          {
+      |            "name": "parent",
+      |            "type": "LongList"
+      |          }
+      |        ]
+      |      }
+      |    }
+      |  ]
+      |}
+    """.stripMargin)
+
+    checkSchemaWithRecursiveLoop("""
+      |{
+      |  "type": "record",
+      |  "name": "LongList",
+      |  "fields" : [
+      |    {"name": "value", "type": "long"},
+      |    {"name": "array", "type": {"type": "array", "items": "LongList"}}
+      |  ]
+      |}
+    """.stripMargin)
+
+    checkSchemaWithRecursiveLoop("""
+      |{
+      |  "type": "record",
+      |  "name": "LongList",
+      |  "fields" : [
+      |    {"name": "value", "type": "long"},
+      |    {"name": "map", "type": {"type": "map", "values": "LongList"}}
+      |  ]
+      |}
+    """.stripMargin)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to