baganokodo2022 commented on code in PR #38922:
URL: https://github.com/apache/spark/pull/38922#discussion_r1043901516
##
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala:
##
@@ -40,19 +40,26 @@ object SchemaConverters {
*
* @since 3.4.0
*/
- def toSqlType(descriptor: Descriptor): SchemaType = {
-toSqlTypeHelper(descriptor)
+ def toSqlType(
+ descriptor: Descriptor,
+ protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)):
SchemaType = {
+toSqlTypeHelper(descriptor, protobufOptions)
}
- def toSqlTypeHelper(descriptor: Descriptor): SchemaType =
ScalaReflectionLock.synchronized {
+ def toSqlTypeHelper(
+ descriptor: Descriptor,
+ protobufOptions: ProtobufOptions): SchemaType =
ScalaReflectionLock.synchronized {
SchemaType(
- StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_,
Set.empty)).toArray),
+ StructType(descriptor.getFields.asScala.flatMap(
+structFieldFor(_, Map.empty, Map.empty, protobufOptions:
ProtobufOptions)).toArray),
nullable = true)
}
def structFieldFor(
fd: FieldDescriptor,
- existingRecordNames: Set[String]): Option[StructField] = {
+ existingRecordNames: Map[String, Int],
+ existingRecordTypes: Map[String, Int],
Review Comment:
@SandishKumarHN since it is going to be either `FIELD_NAME` or `FIELD_TYPE`,
do we need to keep both Maps?
##
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala:
##
@@ -38,6 +38,12 @@ private[sql] class ProtobufOptions(
val parseMode: ParseMode =
parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+
+ val circularReferenceType: String =
parameters.getOrElse("circularReferenceType", "FIELD_NAME")
Review Comment:
>
Yes @SandishKumarHN, you are right. That was discovered in a very complex
Proto schema shared across many microservices.
##
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala:
##
@@ -92,14 +109,38 @@ object SchemaConverters {
MapType(keyType, valueType, valueContainsNull =
false).defaultConcreteType,
nullable = false))
case MESSAGE =>
-if (existingRecordNames.contains(fd.getFullName)) {
- throw
QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+// User can set circularReferenceDepth of 0 or 1 or 2.
+// Going beyond 3 levels of recursion is not allowed.
+if (protobufOptions.circularReferenceType.equals("FIELD_TYPE")) {
+ if (existingRecordTypes.contains(fd.getType.name()) &&
+(protobufOptions.circularReferenceDepth < 0 ||
+ protobufOptions.circularReferenceDepth >= 3)) {
+throw
QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+ } else if (existingRecordTypes.contains(fd.getType.name()) &&
Review Comment:
@SandishKumarHN and @rangadi, should we error out on `-1`, the default value,
unless users specifically override it?
0 -> drop all recursed fields once encountered.
1 -> allow the same field name (type) to be entered twice.
2 -> allow the same field name (type) to be entered 3 times.
Thoughts?
##
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala:
##
@@ -92,14 +109,38 @@ object SchemaConverters {
MapType(keyType, valueType, valueContainsNull =
false).defaultConcreteType,
nullable = false))
case MESSAGE =>
-if (existingRecordNames.contains(fd.getFullName)) {
- throw
QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+// User can set circularReferenceDepth of 0 or 1 or 2.
+// Going beyond 3 levels of recursion is not allowed.
+if (protobufOptions.circularReferenceType.equals("FIELD_TYPE")) {
+ if (existingRecordTypes.contains(fd.getType.name()) &&
+(protobufOptions.circularReferenceDepth < 0 ||
+ protobufOptions.circularReferenceDepth >= 3)) {
+throw
QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+ } else if (existingRecordTypes.contains(fd.getType.name()) &&
Review Comment:
In my back-ported branch,
```
val recordName = circularReferenceType match {
case CircularReferenceTypes.FIELD_NAME =>
fd.getFullName
case CircularReferenceTypes.FIELD_TYPE =>
fd.getFullName().substring(0, fd.getFullName().lastIndexOf("."))
}
if (circularReferenceTolerance < 0 &&
existingRecordNames(recordName) > 0) {
// no tolerance on circular reference
logError(s"circular reference in protobuf schema detected [no
tolerance] - ${recordName}")
throw new