rangadi commented on code in PR #38922: URL: https://github.com/apache/spark/pull/38922#discussion_r1043726930
########## connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala: ########## @@ -92,14 +109,38 @@ object SchemaConverters { MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, nullable = false)) case MESSAGE => - if (existingRecordNames.contains(fd.getFullName)) { - throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString()) + // User can set circularReferenceDepth of 0 or 1 or 2. + // Going beyond 3 levels of recursion is not allowed. + if (protobufOptions.circularReferenceType.equals("FIELD_TYPE")) { + if (existingRecordTypes.contains(fd.getType.name()) && + (protobufOptions.circularReferenceDepth < 0 || + protobufOptions.circularReferenceDepth >= 3)) { + throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString()) + } else if (existingRecordTypes.contains(fd.getType.name()) && Review Comment: Name or full name? Also, what keeps track of the recursion depth? ########## connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala: ########## @@ -38,6 +38,12 @@ private[sql] class ProtobufOptions( val parseMode: ParseMode = parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) + + val circularReferenceType: String = parameters.getOrElse("circularReferenceType", "FIELD_NAME") Review Comment: @SandishKumarHN @baganokodo2022 moving the discussion here (for threading). > Besides, can we also support a "CircularReferenceType" option with a enum value of [FIELD_NAME, FIELD_TYPE]. The reason is because navigation can go very deep before the same fully-qualified FIELD_NAME is encountered again. While FIELD_TYPE stops recursive navigation much faster. ... I didn't quite follow the motivation here. Could you give a concrete example for each of the two different cases? 
########## connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala: ########## @@ -92,14 +109,38 @@ object SchemaConverters { MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, nullable = false)) case MESSAGE => - if (existingRecordNames.contains(fd.getFullName)) { - throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString()) + // User can set circularReferenceDepth of 0 or 1 or 2. + // Going beyond 3 levels of recursion is not allowed. Review Comment: Could you add a justification for this? ########## connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala: ########## @@ -26,11 +26,11 @@ import com.google.protobuf.{ByteString, DynamicMessage} import org.apache.spark.sql.{Column, QueryTest, Row} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions.{lit, struct} -import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated +import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EventRecursiveA, EventRecursiveB, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated} Review Comment: Are there tests for recursive fields? ########## connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala: ########## @@ -693,4 +693,178 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR", parameters = Map("descFilePath" -> testFileDescriptor)) } + + test("Unit test for Protobuf OneOf field") { Review Comment: Add a short description of the test at the top. It improves readability. What is this verifying? Remove "Unit test for", this is already a unit test :). 
########## connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala: ########## @@ -693,4 +693,178 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR", parameters = Map("descFilePath" -> testFileDescriptor)) } + + test("Unit test for Protobuf OneOf field") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent") + val oneOfEvent = OneOfEvent.newBuilder() + .setKey("key") + .setCol1(123) + .setCol3(109202L) + .setCol2("col2value") + .addCol4("col4value").build() + + val df = Seq(oneOfEvent.toByteArray).toDF("value") + + val fromProtoDf = df.select( + functions.from_protobuf($"value", "OneOfEvent", testFileDesc) as 'sample) + val toDf = fromProtoDf.select( + functions.to_protobuf($"sample", "OneOfEvent", testFileDesc) as 'toProto) + val toFromDf = toDf.select( + functions.from_protobuf($"toProto", "OneOfEvent", testFileDesc) as 'fromToProto) + + checkAnswer(fromProtoDf, toFromDf) + + val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name) + descriptor.getFields.asScala.map(f => { + assert(actualFieldNames.contains(f.getName)) + }) + + val eventFromSpark = OneOfEvent.parseFrom( + toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + + // OneOf field: the last set value(by order) will overwrite all previous ones. 
+ assert(eventFromSpark.getCol2.equals("col2value")) + assert(eventFromSpark.getCol3 == 0) + + val expectedFields = descriptor.getFields.asScala.map(f => f.getName) + eventFromSpark.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + + val schema = new StructType() + .add("sample", + new StructType() + .add("key", StringType) + .add("col_1", IntegerType) + .add("col_2", StringType) + .add("col_3", LongType) + .add("col_4", ArrayType(StringType)) + ) + + val data = Seq(Row(Row("key", 123, "col2value", 109202L, Seq("col4value")))) + val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + val dataDfToProto = dataDf.select( + functions.to_protobuf($"sample", "OneOfEvent", testFileDesc) as 'toProto) + val eventFromSparkSchema = OneOfEvent.parseFrom( + dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + assert(eventFromSparkSchema.getCol2.isEmpty) + assert(eventFromSparkSchema.getCol3 == 109202L) + eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + } + + test("Unit tests for Protobuf OneOf field with circularReferenceDepth option") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEventWithRecursion") + + val recursiveANested = EventRecursiveA.newBuilder() + .setKey("keyNested3").build() + val oneOfEventNested = OneOfEventWithRecursion.newBuilder() + .setKey("keyNested2") + .setValue("valueNested2") + .setRecursiveA(recursiveANested).build() + val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey") + .setRecursiveA(oneOfEventNested).build() + val recursiveB = EventRecursiveB.newBuilder() + .setKey("recursiveBKey") + .setValue("recursiveBvalue").build() + val oneOfEventWithRecursion = OneOfEventWithRecursion.newBuilder() + .setKey("key1") + .setValue("value1") + .setRecursiveB(recursiveB) + .setRecursiveA(recursiveA).build() + + val df = 
Seq(oneOfEventWithRecursion.toByteArray).toDF("value") + + val options = new java.util.HashMap[String, String]() + options.put("circularReferenceDepth", "1") + + val fromProtoDf = df.select( + functions.from_protobuf($"value", + "OneOfEventWithRecursion", + testFileDesc, options) as 'sample) + + val toDf = fromProtoDf.select( + functions.to_protobuf($"sample", "OneOfEventWithRecursion", testFileDesc) as 'toProto) + val toFromDf = toDf.select( + functions.from_protobuf($"toProto", + "OneOfEventWithRecursion", + testFileDesc, + options) as 'fromToProto) + + checkAnswer(fromProtoDf, toFromDf) + + val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name) + descriptor.getFields.asScala.map(f => { + assert(actualFieldNames.contains(f.getName)) + }) + + val eventFromSpark = OneOfEventWithRecursion.parseFrom( + toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + + // check circularReferenceDepth=1 value are present, but not circularReferenceDepth=2 + assert(eventFromSpark.getRecursiveA.getRecursiveA.getKey.equals("keyNested2")) + assert(eventFromSpark.getRecursiveA.getRecursiveA.getValue.equals("valueNested2")) + assert(eventFromSpark.getRecursiveA.getRecursiveA.getRecursiveA.getKey.isEmpty) + + val expectedFields = descriptor.getFields.asScala.map(f => f.getName) + eventFromSpark.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + + val schema = StructType(Seq(StructField("sample", Review Comment: Btw, using `val schema = DataType.fromJson("json string")` is a lot more readable. Optionally, we could update many of these in follow-up PRs. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org