bogao007 commented on code in PR #47425:
URL: https://github.com/apache/spark/pull/47425#discussion_r1690240057


##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala:
##########
@@ -136,6 +136,13 @@ private[sql] class AvroOptions(
 
   val stableIdPrefixForUnionType: String = parameters
     .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_")
+
+  val recursiveFieldMaxDepth: Int =
+    parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1)
+
+  if (recursiveFieldMaxDepth > 15) {
+    throw new IllegalArgumentException(s"Valid range of $RECURSIVE_FIELD_MAX_DEPTH is 0 - 15.")

Review Comment:
   Can we follow the error class strategy to classify this error? You can refer to something like [this](https://github.com/apache/spark/blob/34e65a8e72513d0445b5ff9b251e3388625ad1ec/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala#L81-L87), but consider creating your own error class.
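
   For illustration, a minimal sketch of that suggestion (not the PR's actual fix), using the `SparkException(errorClass, messageParameters, cause)` constructor; the error class name `AVRO_INVALID_RECURSIVE_FIELD_MAX_DEPTH` is hypothetical and would need a matching entry in the error-conditions JSON resource:

   ```scala
   import org.apache.spark.SparkException

   // Hypothetical error class name; it must be registered in the
   // error-conditions JSON resource before it can be thrown.
   if (recursiveFieldMaxDepth > 15) {
     throw new SparkException(
       errorClass = "AVRO_INVALID_RECURSIVE_FIELD_MAX_DEPTH",
       messageParameters = Map(
         "recursiveFieldMaxDepth" -> recursiveFieldMaxDepth.toString),
       cause = null)
   }
   ```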



##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala:
##########
@@ -136,6 +136,13 @@ private[sql] class AvroOptions(
 
   val stableIdPrefixForUnionType: String = parameters
     .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_")
+
+  val recursiveFieldMaxDepth: Int =
+    parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1)
+
+  if (recursiveFieldMaxDepth > 15) {

Review Comment:
   Is 15 the max depth we allow? Can we use a constant like `RECURSIVE_FIELD_DEPTH_LIMIT` to represent it?
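
   As a sketch of that suggestion (the constant name is taken from this comment; placing it alongside the existing option-name constants in the `AvroOptions` companion object is an assumption):

   ```scala
   private[sql] object AvroOptions {
     // Upper bound for the recursiveFieldMaxDepth option.
     val RECURSIVE_FIELD_DEPTH_LIMIT: Int = 15
   }
   ```

   The check then becomes `if (recursiveFieldMaxDepth > RECURSIVE_FIELD_DEPTH_LIMIT)`, so the limit is defined in exactly one place.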



##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala:
##########
@@ -128,62 +137,113 @@ object SchemaConverters {
       case NULL => SchemaType(NullType, nullable = true)
 
       case RECORD =>
-        if (existingRecordNames.contains(avroSchema.getFullName)) {
+        val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0)
+        if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) {
           throw new IncompatibleSchemaException(s"""
-            |Found recursive reference in Avro schema, which can not be processed by Spark:
-            |${avroSchema.toString(true)}
+            |Found recursive reference in Avro schema, which can not be processed by Spark by
+            | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth`
+            | to 1 - 15. Going beyond 15 levels of recursion is not allowed.
           """.stripMargin)
-        }
-        val newRecordNames = existingRecordNames + avroSchema.getFullName
-        val fields = avroSchema.getFields.asScala.map { f =>
-          val schemaType = toSqlTypeHelper(
-            f.schema(),
-            newRecordNames,
-            useStableIdForUnionType,
-            stableIdPrefixForUnionType)
-          StructField(f.name, schemaType.dataType, schemaType.nullable)
-        }
+        } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) {
+          logInfo(
+            log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " 
+
+              log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at 
recursive depth " +
+              log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}."
+          )
+          null
+        } else {
+          val newRecordNames =
+            existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1))
+          val fields = avroSchema.getFields.asScala.map { f =>
+            val schemaType = toSqlTypeHelper(
+              f.schema(),
+              newRecordNames,
+              useStableIdForUnionType,
+              stableIdPrefixForUnionType,
+              recursiveFieldMaxDepth)
+            if (schemaType == null) {
+              null
+            }
+            else {
+              StructField(f.name, schemaType.dataType, schemaType.nullable)
+            }
+          }.filter(_ != null).toSeq
 
-        SchemaType(StructType(fields.toArray), nullable = false)
+          SchemaType(StructType(fields), nullable = false)
+        }
 
       case ARRAY =>
         val schemaType = toSqlTypeHelper(
           avroSchema.getElementType,
           existingRecordNames,
           useStableIdForUnionType,
-          stableIdPrefixForUnionType)
-        SchemaType(
-          ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
-          nullable = false)
+          stableIdPrefixForUnionType,
+          recursiveFieldMaxDepth)
+        if (schemaType == null) {
+          logInfo(
+            log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " +
+              log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does 
not have any " +
+              log"fields left likely due to recursive depth limit."
+          )
+          null
+        } else {
+          SchemaType(
+            ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
+            nullable = false)
+        }
 
       case MAP =>
         val schemaType = toSqlTypeHelper(avroSchema.getValueType,
-          existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType)
-        SchemaType(
-          MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
-          nullable = false)
+          existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType,
+          recursiveFieldMaxDepth)
+        if (schemaType == null) {
+          logInfo(
+            log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " +
+              log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does 
not have any " +
+              log"fields left likely due to recursive depth limit."
+          )
+          null
+        } else {
+          SchemaType(
+            MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
+            nullable = false)
+        }
 
       case UNION =>
         if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
          // In case of a union with null, eliminate it and make a recursive call
           val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
-          if (remainingUnionTypes.size == 1) {
-            toSqlTypeHelper(
-              remainingUnionTypes.head,
-              existingRecordNames,
-              useStableIdForUnionType,
-              stableIdPrefixForUnionType).copy(nullable = true)
+          val schemaType =
+            if (remainingUnionTypes.size == 1) {
+              toSqlTypeHelper(
+                remainingUnionTypes.head,
+                existingRecordNames,
+                useStableIdForUnionType,
+                stableIdPrefixForUnionType,
+                recursiveFieldMaxDepth)
+            } else {
+              toSqlTypeHelper(
+                Schema.createUnion(remainingUnionTypes.asJava),

Review Comment:
   nit: Maybe we could just use an if-else clause to pick the union schema, since all the other arguments to `toSqlTypeHelper` are the same?
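
   A sketch of the nit, assuming a local `schemaToConvert` value (hypothetical name) so that `toSqlTypeHelper` is called only once:

   ```scala
   // Pick the schema with an if-else; every other argument is identical.
   val schemaToConvert =
     if (remainingUnionTypes.size == 1) {
       remainingUnionTypes.head
     } else {
       Schema.createUnion(remainingUnionTypes.asJava)
     }
   val schemaType = toSqlTypeHelper(
     schemaToConvert,
     existingRecordNames,
     useStableIdForUnionType,
     stableIdPrefixForUnionType,
     recursiveFieldMaxDepth)
   ```

   The `.copy(nullable = true)` from the surrounding code would then presumably be applied to the single `schemaType` result.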



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

