rangadi commented on code in PR #38922:
URL: https://github.com/apache/spark/pull/38922#discussion_r1043726930


##########
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala:
##########
@@ -92,14 +109,38 @@ object SchemaConverters {
             MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType,
             nullable = false))
       case MESSAGE =>
-        if (existingRecordNames.contains(fd.getFullName)) {
-          throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+        // User can set circularReferenceDepth of 0 or 1 or 2.
+        // Going beyond 3 levels of recursion is not allowed.
+        if (protobufOptions.circularReferenceType.equals("FIELD_TYPE")) {
+          if (existingRecordTypes.contains(fd.getType.name()) &&
+            (protobufOptions.circularReferenceDepth < 0 ||
+              protobufOptions.circularReferenceDepth >= 3)) {
+            throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+          } else if (existingRecordTypes.contains(fd.getType.name()) &&

Review Comment:
   Name or full name? Also, what keeps track of the recursion depth?
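   
   To make sure I'm asking the right thing, the kind of bookkeeping I'd expect is roughly the following (purely a sketch; `enterMessage`, `recordDepths`, and `maxDepth` are made-up names, not from this PR):
   
   ```scala
   import com.google.protobuf.Descriptors.FieldDescriptor
   
   // Count how many times each message type has been entered on the current path
   // and stop (or fail) once the configured depth is exceeded.
   def enterMessage(
       fd: FieldDescriptor,
       recordDepths: Map[String, Int],
       maxDepth: Int): Either[String, Map[String, Int]] = {
     val typeName = fd.getMessageType.getFullName
     val depth = recordDepths.getOrElse(typeName, 0)
     if (depth > maxDepth) Left(s"recursion limit exceeded for $typeName")
     else Right(recordDepths.updated(typeName, depth + 1)) // pass the updated map into the recursive call
   }
   ```
   
   Is a map (or counter) like that threaded through the recursive calls here?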



##########
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala:
##########
@@ -38,6 +38,12 @@ private[sql] class ProtobufOptions(
 
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+
+  val circularReferenceType: String = parameters.getOrElse("circularReferenceType", "FIELD_NAME")

Review Comment:
   @SandishKumarHN @baganokodo2022 moving the discussion here (for threading).
   
   > Besides, can we also support a "CircularReferenceType" option with a enum value of [FIELD_NAME, FIELD_TYPE]. The reason is because navigation can go very deep before the same fully-qualified FIELD_NAME is encountered again. While FIELD_TYPE stops recursive navigation much faster.  ...
   
   I didn't quite follow the motivation here. Could you give concrete examples of the two different cases?
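   
   For my own understanding, here is how I read the difference, in terms of which identifier would be tracked along the current path (just a sketch to confirm the intent, not this PR's code):
   
   ```scala
   import com.google.protobuf.Descriptors.FieldDescriptor
   
   // FIELD_NAME keys recursion detection on the fully-qualified field name, so the same
   // message type reached through differently named fields produces distinct keys and can
   // be expanded many times before a key repeats. FIELD_TYPE keys it on the message type
   // itself, so the check trips as soon as the same type reappears on the path.
   def recursionKey(fd: FieldDescriptor, circularReferenceType: String): String =
     circularReferenceType match {
       case "FIELD_TYPE" => fd.getMessageType.getFullName // e.g. "com.example.Person"
       case _            => fd.getFullName                // e.g. "com.example.Person.manager"
     }
   ```
   
   If that reading is right, a small pair of example messages showing where the two modes diverge would still be useful to spell out.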
   



##########
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala:
##########
@@ -92,14 +109,38 @@ object SchemaConverters {
             MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType,
             nullable = false))
       case MESSAGE =>
-        if (existingRecordNames.contains(fd.getFullName)) {
-          throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+        // User can set circularReferenceDepth of 0 or 1 or 2.
+        // Going beyond 3 levels of recursion is not allowed.

Review Comment:
   Could you add a justification for this?



##########
connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala:
##########
@@ -26,11 +26,11 @@ import com.google.protobuf.{ByteString, DynamicMessage}
 import org.apache.spark.sql.{Column, QueryTest, Row}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.functions.{lit, struct}
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EventRecursiveA, EventRecursiveB, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}

Review Comment:
   Are there tests for recursive fields? 



##########
connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala:
##########
@@ -693,4 +693,178 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
       errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
       parameters = Map("descFilePath" -> testFileDescriptor))
   }
+
+  test("Unit test for Protobuf OneOf field") {

Review Comment:
   Add a short description of the test at the top. It improves readability. What is this verifying?
   
   Remove "Unit test for", this is already a unit test :). 
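   
   For example, something like `test("roundtrip in from_protobuf and to_protobuf for OneOf fields")` plus a one-line comment on what is being checked would do; the exact wording is only a suggestion.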



##########
connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala:
##########
@@ -693,4 +693,178 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
       errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
       parameters = Map("descFilePath" -> testFileDescriptor))
   }
+
+  test("Unit test for Protobuf OneOf field") {
+    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent")
+    val oneOfEvent = OneOfEvent.newBuilder()
+      .setKey("key")
+      .setCol1(123)
+      .setCol3(109202L)
+      .setCol2("col2value")
+      .addCol4("col4value").build()
+
+    val df = Seq(oneOfEvent.toByteArray).toDF("value")
+
+    val fromProtoDf = df.select(
+      functions.from_protobuf($"value", "OneOfEvent", testFileDesc) as 'sample)
+    val toDf = fromProtoDf.select(
+      functions.to_protobuf($"sample", "OneOfEvent", testFileDesc) as 'toProto)
+    val toFromDf = toDf.select(
+      functions.from_protobuf($"toProto", "OneOfEvent", testFileDesc) as 'fromToProto)
+
+    checkAnswer(fromProtoDf, toFromDf)
+
+    val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
+    descriptor.getFields.asScala.map(f => {
+      assert(actualFieldNames.contains(f.getName))
+    })
+
+    val eventFromSpark = OneOfEvent.parseFrom(
+      toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+
+    // OneOf field: the last set value (by order) will overwrite all previous ones.
+    assert(eventFromSpark.getCol2.equals("col2value"))
+    assert(eventFromSpark.getCol3 == 0)
+
+    val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
+    eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
+      assert(expectedFields.contains(f.getName))
+    })
+
+    val schema = new StructType()
+      .add("sample",
+        new StructType()
+          .add("key", StringType)
+          .add("col_1", IntegerType)
+          .add("col_2", StringType)
+          .add("col_3", LongType)
+          .add("col_4", ArrayType(StringType))
+      )
+
+    val data = Seq(Row(Row("key", 123, "col2value", 109202L, Seq("col4value"))))
+    val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    val dataDfToProto = dataDf.select(
+      functions.to_protobuf($"sample", "OneOfEvent", testFileDesc) as 'toProto)
+    val eventFromSparkSchema = OneOfEvent.parseFrom(
+      dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+    assert(eventFromSparkSchema.getCol2.isEmpty)
+    assert(eventFromSparkSchema.getCol3 == 109202L)
+    eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => {
+      assert(expectedFields.contains(f.getName))
+    })
+  }
+
+  test("Unit tests for Protobuf OneOf field with circularReferenceDepth option") {
+    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, 
"OneOfEventWithRecursion")
+
+    val recursiveANested = EventRecursiveA.newBuilder()
+      .setKey("keyNested3").build()
+    val oneOfEventNested = OneOfEventWithRecursion.newBuilder()
+      .setKey("keyNested2")
+      .setValue("valueNested2")
+      .setRecursiveA(recursiveANested).build()
+    val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey")
+      .setRecursiveA(oneOfEventNested).build()
+    val recursiveB = EventRecursiveB.newBuilder()
+      .setKey("recursiveBKey")
+      .setValue("recursiveBvalue").build()
+    val oneOfEventWithRecursion = OneOfEventWithRecursion.newBuilder()
+      .setKey("key1")
+      .setValue("value1")
+      .setRecursiveB(recursiveB)
+      .setRecursiveA(recursiveA).build()
+
+    val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")
+
+    val options = new java.util.HashMap[String, String]()
+    options.put("circularReferenceDepth", "1")
+
+    val fromProtoDf = df.select(
+      functions.from_protobuf($"value",
+        "OneOfEventWithRecursion",
+        testFileDesc, options) as 'sample)
+
+    val toDf = fromProtoDf.select(
+      functions.to_protobuf($"sample", "OneOfEventWithRecursion", testFileDesc) as 'toProto)
+    val toFromDf = toDf.select(
+      functions.from_protobuf($"toProto",
+        "OneOfEventWithRecursion",
+        testFileDesc,
+        options) as 'fromToProto)
+
+    checkAnswer(fromProtoDf, toFromDf)
+
+    val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
+    descriptor.getFields.asScala.map(f => {
+      assert(actualFieldNames.contains(f.getName))
+    })
+
+    val eventFromSpark = OneOfEventWithRecursion.parseFrom(
+      toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+
+    // check that circularReferenceDepth=1 values are present, but not circularReferenceDepth=2
+    assert(eventFromSpark.getRecursiveA.getRecursiveA.getKey.equals("keyNested2"))
+    assert(eventFromSpark.getRecursiveA.getRecursiveA.getValue.equals("valueNested2"))
+    assert(eventFromSpark.getRecursiveA.getRecursiveA.getRecursiveA.getKey.isEmpty)
+
+    val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
+    eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
+      assert(expectedFields.contains(f.getName))
+    })
+
+    val schema = StructType(Seq(StructField("sample",

Review Comment:
   Btw, using `val schema = DataType.fromJson("json string")` is a lot more readable.
   Optionally, we could update many of these in follow-up PRs.
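   
   Roughly what I have in mind (the JSON is hand-written to mirror the fields used above, so treat it as a sketch rather than a drop-in replacement):
   
   ```scala
   import org.apache.spark.sql.types.{DataType, StructType}
   
   // Build the same struct from its JSON representation instead of nested StructType builders.
   val sampleSchema = DataType.fromJson(
     """{"type":"struct","fields":[
       |  {"name":"key","type":"string","nullable":true,"metadata":{}},
       |  {"name":"col_1","type":"integer","nullable":true,"metadata":{}},
       |  {"name":"col_2","type":"string","nullable":true,"metadata":{}},
       |  {"name":"col_3","type":"long","nullable":true,"metadata":{}},
       |  {"name":"col_4","type":{"type":"array","elementType":"string","containsNull":true},"nullable":true,"metadata":{}}
       |]}""".stripMargin).asInstanceOf[StructType]
   ```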



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

