This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dd5ce947d808 [SPARK-55062][PROTOBUF] Support proto2 extensions in
protobuf functions
dd5ce947d808 is described below
commit dd5ce947d80855b3793e5f33e7cf51c593d897e6
Author: David Young <[email protected]>
AuthorDate: Mon Feb 23 10:24:10 2026 -0800
[SPARK-55062][PROTOBUF] Support proto2 extensions in protobuf functions
### What changes were proposed in this pull request?
This PR adds support for proto2 extensions to `from_protobuf` and
`to_protobuf` (when file descriptor set is provided, as Java classes do not
contain enough information to support extensions).
This is done by building an ExtensionRegistry and a map from descriptor
name to its extensions. The registry is used during construction of the
DynamicMessage to provide the Protobuf library with visibility of the
extensions. The index is plumbed through the various helper classes for use in
schema conversion and serde.
This new functionality is gated behind the Spark config property
`spark.sql.function.protobufExtensions.enabled`.
### Why are the changes needed?
Proto2 extensions are a valid, if somewhat uncommon, feature of Protobuf,
and it therefore makes sense to incorporate them into the schema when provided
so as to not confuse the user.
### Does this PR introduce _any_ user-facing change?
Yes. Previously, extension fields would be dropped by both `from_protobuf`
and `to_protobuf`. Now, they are retained. This can be demonstrated with the
minimal example below. See the unit tests for more examples.
```proto
message Person {
optional int32 id = 1;
extensions 100 to 200;
}
extend Person {
optional int32 age = 100;
}
```
### How was this patch tested?
Unit tests were added for the new behavior, including basic behavior,
extending nested messages, and extensions defined in separate files.
### Was this patch authored or co-authored using generative AI tooling?
Initial draft authored with Claude Code.
Generated-by: claude-4.5-opus
Closes #53828 from dichlorodiphen/proto-rebase.
Authored-by: David Young <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../sql/protobuf/CatalystDataToProtobuf.scala | 9 +-
.../sql/protobuf/ProtobufDataToCatalyst.scala | 31 +-
.../spark/sql/protobuf/ProtobufDeserializer.scala | 76 ++-
.../spark/sql/protobuf/ProtobufSerializer.scala | 53 +-
.../spark/sql/protobuf/utils/ProtobufUtils.scala | 203 ++++++--
.../sql/protobuf/utils/SchemaConverters.scala | 57 ++-
.../test/resources/protobuf/proto2_messages.proto | 69 +++
.../resources/protobuf/proto2_messages_ext.proto | 32 ++
.../ProtobufCatalystDataConversionSuite.scala | 15 +-
.../sql/protobuf/ProtobufExtensionsSuite.scala | 538 +++++++++++++++++++++
.../sql/protobuf/ProtobufFunctionsSuite.scala | 64 ++-
.../spark/sql/protobuf/ProtobufSerdeSuite.scala | 35 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 9 +
.../streaming/checkpointing/OffsetSeq.scala | 5 +-
.../ProtobufExtensionsSupportOffsetLogSuite.scala | 110 +++++
15 files changed, 1155 insertions(+), 151 deletions(-)
diff --git
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
index 0564eee1602a..9ed056690d26 100644
---
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
+++
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -35,11 +35,16 @@ private[sql] case class CatalystDataToProtobuf(
override def dataType: DataType = BinaryType
- @transient private lazy val protoDescriptor =
+ @transient private lazy val descriptorWithExtensions =
ProtobufUtils.buildDescriptor(messageName, binaryFileDescriptorSet)
+ @transient private lazy val protoDescriptor =
descriptorWithExtensions.descriptor
+
+ @transient private lazy val fullNamesToExtensions =
+ descriptorWithExtensions.fullNamesToExtensions
+
@transient private lazy val serializer =
- new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable)
+ new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable,
fullNamesToExtensions)
override def nullSafeEval(input: Any): Any = {
val dynamicMessage =
serializer.serialize(input).asInstanceOf[DynamicMessage]
diff --git
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
index b3225d61eb01..864eb93a382b 100644
---
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
+++
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -40,17 +40,24 @@ private[sql] case class ProtobufDataToCatalyst(
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
override lazy val dataType: DataType =
- SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType
+ SchemaConverters.toSqlType(messageDescriptor, fullNamesToExtensions,
protobufOptions).dataType
override def nullable: Boolean = true
private lazy val protobufOptions = ProtobufOptions(options)
- @transient private lazy val messageDescriptor =
+ @transient private lazy val descriptorWithExtensions =
ProtobufUtils.buildDescriptor(messageName, binaryFileDescriptorSet)
- @transient private lazy val fieldsNumbers =
- messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet
+ @transient private lazy val messageDescriptor =
descriptorWithExtensions.descriptor
+
+ @transient private lazy val extensionRegistry =
descriptorWithExtensions.extensionRegistry
+
+ @transient private lazy val fullNamesToExtensions =
+ descriptorWithExtensions.fullNamesToExtensions
+
+ @transient private lazy val regularFieldNumbers =
+ messageDescriptor.getFields.asScala.map(_.getNumber).toSet
@transient private lazy val deserializer = {
val typeRegistry = binaryFileDescriptorSet match {
@@ -65,8 +72,8 @@ private[sql] case class ProtobufDataToCatalyst(
dataType,
typeRegistry = typeRegistry,
emitDefaultValues = protobufOptions.emitDefaultValues,
- enumsAsInts = protobufOptions.enumsAsInts
- )
+ enumsAsInts = protobufOptions.enumsAsInts,
+ fullNamesToExtensions = fullNamesToExtensions)
}
@transient private var result: DynamicMessage = _
@@ -92,14 +99,18 @@ private[sql] case class ProtobufDataToCatalyst(
override def nullSafeEval(input: Any): Any = {
val binary = input.asInstanceOf[Array[Byte]]
try {
- result = DynamicMessage.parseFrom(messageDescriptor, binary)
+ result = DynamicMessage.parseFrom(messageDescriptor, binary,
extensionRegistry)
// If the Java class is available, it is likely more efficient to parse
with it than using
// DynamicMessage. Can consider it in the future if parsing overhead is
noticeable.
-
result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_))
match {
+ result.getUnknownFields
+ .asMap()
+ .keySet()
+ .asScala
+ .find(regularFieldNumbers.contains(_)) match {
case Some(number) =>
- // Unknown fields contain a field with same number as a known field.
Must be due to
- // mismatch of schema between writer and reader here.
+ // Unknown fields contain a field with same number as a known
regular field. Must be due
+ // to mismatch of schema between writer and reader here.
throw QueryCompilationErrors.protobufFieldTypeMismatchError(
messageDescriptor.getFields.get(number).toString)
case None =>
diff --git
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index fa6567ae2aa5..9f2bafb8a180 100644
---
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -40,7 +40,8 @@ private[sql] class ProtobufDeserializer(
filters: StructFilters = new NoopFilters,
typeRegistry: TypeRegistry = TypeRegistry.getEmptyTypeRegistry,
emitDefaultValues: Boolean = false,
- enumsAsInts: Boolean = false) {
+ enumsAsInts: Boolean = false,
+ fullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = Map.empty) {
def this(rootDescriptor: Descriptor, rootCatalystType: DataType) = {
this(
@@ -59,7 +60,8 @@ private[sql] class ProtobufDeserializer(
val resultRow = new SpecificInternalRow(st.map(_.dataType))
val fieldUpdater = new RowUpdater(resultRow)
val applyFilters = filters.skipRow(resultRow, _)
- val writer = getRecordWriter(rootDescriptor, st, Nil, Nil,
applyFilters)
+ val writer =
+ getRecordWriter(rootDescriptor, st, Nil, Nil, applyFilters,
fullNamesToExtensions)
(data: Any) => {
val record = data.asInstanceOf[DynamicMessage]
val skipRow = writer(fieldUpdater, record)
@@ -97,11 +99,18 @@ private[sql] class ProtobufDeserializer(
protoPath: Seq[String],
catalystPath: Seq[String],
elementType: DataType,
- containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+ containsNull: Boolean,
+ protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] =
Map.empty)
+ : (CatalystDataUpdater, Int, Any) => Unit = {
val protoElementPath = protoPath :+ "element"
val elementWriter =
- newWriter(protoField, elementType, protoElementPath, catalystPath :+
"element")
+ newWriter(
+ protoField,
+ elementType,
+ protoElementPath,
+ catalystPath :+ "element",
+ protoFullNamesToExtensions)
(updater, ordinal, value) =>
val collection = value.asInstanceOf[java.util.Collection[Any]]
val result = createArrayData(elementType, collection.size())
@@ -133,12 +142,24 @@ private[sql] class ProtobufDeserializer(
catalystPath: Seq[String],
keyType: DataType,
valueType: DataType,
- valueContainsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+ valueContainsNull: Boolean,
+ protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] =
Map.empty)
+ : (CatalystDataUpdater, Int, Any) => Unit = {
val keyField = protoType.getMessageType.getFields.get(0)
val valueField = protoType.getMessageType.getFields.get(1)
- val keyWriter = newWriter(keyField, keyType, protoPath :+ "key",
catalystPath :+ "key")
+ val keyWriter = newWriter(
+ keyField,
+ keyType,
+ protoPath :+ "key",
+ catalystPath :+ "key",
+ protoFullNamesToExtensions)
val valueWriter =
- newWriter(valueField, valueType, protoPath :+ "value", catalystPath :+
"value")
+ newWriter(
+ valueField,
+ valueType,
+ protoPath :+ "value",
+ catalystPath :+ "value",
+ protoFullNamesToExtensions)
(updater, ordinal, value) =>
if (value != null) {
val messageList =
value.asInstanceOf[java.util.List[com.google.protobuf.Message]]
@@ -174,7 +195,9 @@ private[sql] class ProtobufDeserializer(
protoType: FieldDescriptor,
catalystType: DataType,
protoPath: Seq[String],
- catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+ catalystPath: Seq[String],
+ protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] =
Map.empty)
+ : (CatalystDataUpdater, Int, Any) => Unit = {
(protoType.getJavaType, catalystType) match {
@@ -201,7 +224,13 @@ private[sql] class ProtobufDeserializer(
case (
MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM |
BYTE_STRING,
ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
- newArrayWriter(protoType, protoPath, catalystPath, dataType,
containsNull)
+ newArrayWriter(
+ protoType,
+ protoPath,
+ catalystPath,
+ dataType,
+ containsNull,
+ protoFullNamesToExtensions)
case (LONG, LongType) =>
(updater, ordinal, value) => updater.setLong(ordinal,
value.asInstanceOf[Long])
@@ -236,7 +265,14 @@ private[sql] class ProtobufDeserializer(
updater.set(ordinal, byte_array)
case (MESSAGE, MapType(keyType, valueType, valueContainsNull)) =>
- newMapWriter(protoType, protoPath, catalystPath, keyType, valueType,
valueContainsNull)
+ newMapWriter(
+ protoType,
+ protoPath,
+ catalystPath,
+ keyType,
+ valueType,
+ valueContainsNull,
+ protoFullNamesToExtensions)
case (MESSAGE, TimestampType) =>
(updater, ordinal, value) =>
@@ -368,7 +404,8 @@ private[sql] class ProtobufDeserializer(
st,
protoPath,
catalystPath,
- applyFilters = _ => false)
+ applyFilters = _ => false,
+ protoFullNamesToExtensions)
(updater, ordinal, value) =>
val row = new SpecificInternalRow(st)
writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
@@ -399,10 +436,20 @@ private[sql] class ProtobufDeserializer(
catalystType: StructType,
protoPath: Seq[String],
catalystPath: Seq[String],
- applyFilters: Int => Boolean): (CatalystDataUpdater, DynamicMessage) =>
Boolean = {
+ applyFilters: Int => Boolean,
+ protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] =
Map.empty)
+ : (CatalystDataUpdater, DynamicMessage) => Boolean = {
+ // Get extension fields for this specific message type
+ val protoExtensionFields =
+ protoFullNamesToExtensions.getOrElse(protoType.getFullName, Seq.empty)
val protoSchemaHelper =
- new ProtobufUtils.ProtoSchemaHelper(protoType, catalystType, protoPath,
catalystPath)
+ new ProtobufUtils.ProtoSchemaHelper(
+ protoType,
+ catalystType,
+ protoPath,
+ catalystPath,
+ protoExtensionFields)
// TODO revisit validation of protobuf-catalyst fields.
// protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true)
@@ -414,7 +461,8 @@ private[sql] class ProtobufDeserializer(
protoField,
catalystField.dataType,
protoPath :+ protoField.getName,
- catalystPath :+ catalystField.name)
+ catalystPath :+ catalystField.name,
+ protoFullNamesToExtensions)
val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
if (value == null) {
fieldUpdater.setNullAt(ordinal)
diff --git
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
index 65e8cce0d056..cd638f3ac12a 100644
---
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
+++
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types._
private[sql] class ProtobufSerializer(
rootCatalystType: DataType,
rootDescriptor: Descriptor,
- nullable: Boolean)
+ nullable: Boolean,
+ fullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = Map.empty)
extends Logging {
def serialize(catalystData: Any): Any = {
@@ -54,7 +55,8 @@ private[sql] class ProtobufSerializer(
try {
rootCatalystType match {
case st: StructType =>
- newStructConverter(st, rootDescriptor, Nil, Nil).asInstanceOf[Any
=> Any]
+ newStructConverter(st, rootDescriptor, Nil, Nil,
fullNamesToExtensions)
+ .asInstanceOf[Any => Any]
}
} catch {
case ise: AnalysisException =>
@@ -80,7 +82,8 @@ private[sql] class ProtobufSerializer(
catalystType: DataType,
fieldDescriptor: FieldDescriptor,
catalystPath: Seq[String],
- protoPath: Seq[String]): Converter = {
+ protoPath: Seq[String],
+ protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]]):
Converter = {
(catalystType, fieldDescriptor.getJavaType) match {
case (NullType, _) =>
(getter, ordinal) => null
@@ -157,7 +160,12 @@ private[sql] class ProtobufSerializer(
case (ArrayType(et, containsNull), _) =>
val elementConverter =
- newConverter(et, fieldDescriptor, catalystPath :+ "element",
protoPath :+ "element")
+ newConverter(
+ et,
+ fieldDescriptor,
+ catalystPath :+ "element",
+ protoPath :+ "element",
+ protoFullNamesToExtensions)
(getter, ordinal) => {
val arrayData = getter.getArray(ordinal)
val len = arrayData.numElements()
@@ -224,7 +232,12 @@ private[sql] class ProtobufSerializer(
case (st: StructType, MESSAGE) =>
val structConverter =
- newStructConverter(st, fieldDescriptor.getMessageType, catalystPath,
protoPath)
+ newStructConverter(
+ st,
+ fieldDescriptor.getMessageType,
+ catalystPath,
+ protoPath,
+ protoFullNamesToExtensions)
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal,
numFields))
@@ -240,9 +253,19 @@ private[sql] class ProtobufSerializer(
}
}
- val keyConverter = newConverter(kt, keyField, catalystPath :+ "key",
protoPath :+ "key")
+ val keyConverter = newConverter(
+ kt,
+ keyField,
+ catalystPath :+ "key",
+ protoPath :+ "key",
+ protoFullNamesToExtensions)
val valueConverter =
- newConverter(vt, valueField, catalystPath :+ "value", protoPath :+
"value")
+ newConverter(
+ vt,
+ valueField,
+ catalystPath :+ "value",
+ protoPath :+ "value",
+ protoFullNamesToExtensions)
(getter, ordinal) =>
val mapData = getter.getMap(ordinal)
@@ -301,10 +324,19 @@ private[sql] class ProtobufSerializer(
catalystStruct: StructType,
descriptor: Descriptor,
catalystPath: Seq[String],
- protoPath: Seq[String]): InternalRow => DynamicMessage = {
+ protoPath: Seq[String],
+ protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]])
+ : InternalRow => DynamicMessage = {
+ val protoExtensionFields =
+ protoFullNamesToExtensions.getOrElse(descriptor.getFullName, Seq.empty)
val protoSchemaHelper =
- new ProtobufUtils.ProtoSchemaHelper(descriptor, catalystStruct,
protoPath, catalystPath)
+ new ProtobufUtils.ProtoSchemaHelper(
+ descriptor,
+ catalystStruct,
+ protoPath,
+ catalystPath,
+ protoExtensionFields)
protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false)
protoSchemaHelper.validateNoExtraRequiredProtoFields()
@@ -315,7 +347,8 @@ private[sql] class ProtobufSerializer(
catalystField.dataType,
protoField,
catalystPath :+ catalystField.name,
- protoPath :+ protoField.getName)
+ protoPath :+ protoField.getName,
+ protoFullNamesToExtensions)
(protoField, converter)
}
.toArray
diff --git
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index 3d7bba7a82e8..47898f73d016 100644
---
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.protobuf.utils
import java.util.Locale
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
-import com.google.protobuf.{DescriptorProtos, Descriptors,
InvalidProtocolBufferException, Message}
+import com.google.protobuf.{DescriptorProtos, Descriptors, DynamicMessage,
ExtensionRegistry, InvalidProtocolBufferException, Message}
import com.google.protobuf.DescriptorProtos.{FileDescriptorProto,
FileDescriptorSet}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.Descriptors.FieldDescriptor.JavaType
import com.google.protobuf.TypeRegistry
import org.apache.spark.internal.Logging
@@ -40,6 +43,21 @@ private[sql] object ProtobufUtils extends Logging {
catalystPosition: Int,
fieldDescriptor: FieldDescriptor)
+ /**
+ * Container for a Protobuf descriptor and associated extensions.
+ *
+ * @param descriptor
+ * The descriptor for the top-level message.
+ * @param extensionRegistry
+ * An extension registry populated with all discovered extensions.
+ * @param fullNamesToExtensions
+ * A map from message full names to their extension fields.
+ */
+ private[sql] case class DescriptorWithExtensions(
+ descriptor: Descriptor,
+ extensionRegistry: ExtensionRegistry,
+ fullNamesToExtensions: Map[String, Seq[FieldDescriptor]])
+
/**
* Helper class to perform field lookup/matching on Protobuf schemas.
*
@@ -55,20 +73,25 @@ private[sql] object ProtobufUtils extends Logging {
* The seq of parent field names leading to `protoSchema`.
* @param catalystPath
* The seq of parent field names leading to `catalystSchema`.
+ * @param extensionFields
+ * Optional sequence of extension fields to include in field matching.
*/
class ProtoSchemaHelper(
descriptor: Descriptor,
catalystSchema: StructType,
protoPath: Seq[String],
- catalystPath: Seq[String]) {
+ catalystPath: Seq[String],
+ extensionFields: Seq[FieldDescriptor] = Seq.empty) {
if (descriptor.getName == null) {
throw QueryCompilationErrors.unknownProtobufMessageTypeError(
descriptor.getName,
descriptor.getContainingType().getName)
}
- private[this] val protoFieldArray = descriptor.getFields.asScala.toArray
- private[this] val fieldMap = descriptor.getFields.asScala
+ // Combine regular fields with extension fields
+ private[this] val protoFieldArray =
+ (descriptor.getFields.asScala ++ extensionFields).toArray
+ private[this] val fieldMap = protoFieldArray
.groupBy(_.getName.toLowerCase(Locale.ROOT))
.transform((_, v) => v.toSeq) // toSeq needed for scala 2.13
@@ -135,27 +158,34 @@ private[sql] object ProtobufUtils extends Logging {
}
/**
- * Builds Protobuf message descriptor either from the Java class or from
serialized descriptor
- * read from the file.
+ * Builds a Protobuf descriptor along with an ExtensionRegistry containing
all extensions found
+ * in the FileDescriptorSet.
+ *
* @param messageName
- * Protobuf message name or Java class name (when binaryFileDescriptorSet
is None)..
+ * Protobuf message name or Java class name (when binaryFileDescriptorSet
is None).
* @param binaryFileDescriptorSet
- * When the binary `FileDescriptorSet` is provided, the descriptor and its
dependencies are
- * read from it.
+ * When the binary `FileDescriptorSet` is provided, the descriptor,
extensions, and registry
+ * are read from it. When None, the descriptor is loaded from the Java
class and extensions
+ * are empty.
* @return
+ * DescriptorWithExtensions containing descriptor, registry, and extension
fields map
*/
- def buildDescriptor(messageName: String, binaryFileDescriptorSet:
Option[Array[Byte]])
- : Descriptor = {
+ def buildDescriptor(
+ messageName: String,
+ binaryFileDescriptorSet: Option[Array[Byte]]): DescriptorWithExtensions =
binaryFileDescriptorSet match {
- case Some(bytes) => buildDescriptor(bytes, messageName)
+ case Some(bytes) => buildDescriptorFromFDS(messageName, bytes)
case None => buildDescriptorFromJavaClass(messageName)
}
- }
/**
- * Loads the given protobuf class and returns Protobuf descriptor for it.
+ * Loads the given protobuf class and returns Protobuf descriptor for it.
+ *
+ * Given a Java class, we can only access the descriptor for the proto file
it is defined
+ * in. Extensions in other files will not be picked up. As such, we choose
to disable
+ * extension support when we fall back to the Java class.
*/
- def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = {
+ def buildDescriptorFromJavaClass(protobufClassName: String):
DescriptorWithExtensions = {
// Default 'Message' class here is shaded while using the package (as in
production).
// The incoming classes might not be shaded. Check both.
@@ -203,23 +233,32 @@ private[sql] object ProtobufUtils extends Logging {
protobufClassName, "Could not find getDescriptor() method", e)
}
- getDescriptorMethod
- .invoke(null)
- .asInstanceOf[Descriptor]
+ val descriptor = getDescriptorMethod.invoke(null).asInstanceOf[Descriptor]
+
+ DescriptorWithExtensions(descriptor, ExtensionRegistry.getEmptyRegistry,
Map.empty)
}
- def buildDescriptor(binaryFileDescriptorSet: Array[Byte], messageName:
String): Descriptor = {
+ def buildDescriptorFromFDS(
+ messageName: String,
+ binaryFileDescriptorSet: Array[Byte]): DescriptorWithExtensions = {
// Find the first message descriptor that matches the name.
- val descriptorOpt = parseFileDescriptorSet(binaryFileDescriptorSet)
+ val fileDescriptors = parseFileDescriptorSet(binaryFileDescriptorSet)
+ val descriptor = fileDescriptors
.flatMap { fileDesc =>
fileDesc.getMessageTypes.asScala.find { desc =>
desc.getName == messageName || desc.getFullName == messageName
}
- }.headOption
+ }
+ .headOption
+ .getOrElse {
+ throw
QueryCompilationErrors.unableToLocateProtobufMessageError(messageName)
+ }
+ val (extensionRegistry, fullNamesToExtensions) =
buildExtensionRegistry(fileDescriptors)
- descriptorOpt match {
- case Some(d) => d
- case None => throw
QueryCompilationErrors.unableToLocateProtobufMessageError(messageName)
+ if (SQLConf.get.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED)) {
+ DescriptorWithExtensions(descriptor, extensionRegistry,
fullNamesToExtensions)
+ } else {
+ DescriptorWithExtensions(descriptor, ExtensionRegistry.getEmptyRegistry,
Map.empty)
}
}
@@ -232,38 +271,49 @@ private[sql] object ProtobufUtils extends Logging {
throw QueryCompilationErrors.descriptorParseError(ex)
}
val fileDescriptorProtoIndex = createDescriptorProtoMap(fileDescriptorSet)
+
+ // Mutated across invocations of buildFileDescriptor.
+ val builtDescriptors = mutable.Map[String, Descriptors.FileDescriptor]()
val fileDescriptorList: List[Descriptors.FileDescriptor] =
- fileDescriptorSet.getFileList.asScala.map( fileDescriptorProto =>
- buildFileDescriptor(fileDescriptorProto, fileDescriptorProtoIndex)
- ).toList
+ fileDescriptorSet.getFileList.asScala.map { fileDescriptorProto =>
+ buildFileDescriptor(fileDescriptorProto, fileDescriptorProtoIndex,
builtDescriptors)
+ }.distinctBy(_.getFullName).toList
fileDescriptorList
}
/**
- * Recursively constructs file descriptors for all dependencies for given
- * FileDescriptorProto and return.
+ * Recursively constructs file descriptors for all dependencies for given
FileDescriptorProto
+ * and return.
*/
private def buildFileDescriptor(
- fileDescriptorProto: FileDescriptorProto,
- fileDescriptorProtoMap: Map[String, FileDescriptorProto]):
Descriptors.FileDescriptor = {
- val fileDescriptorList =
fileDescriptorProto.getDependencyList().asScala.map { dependency =>
- fileDescriptorProtoMap.get(dependency) match {
- case Some(dependencyProto) =>
- if (dependencyProto.getName == "google/protobuf/any.proto"
- && dependencyProto.getPackage == "google.protobuf") {
- // For Any, use the descriptor already included as part of the
Java dependency.
- // Without this, JsonFormat used for converting Any fields fails
when
- // an Any field in input is set to `Any.getDefaultInstance()`.
- com.google.protobuf.AnyProto.getDescriptor
- // Should we do the same for timestamp.proto and empty.proto?
- } else {
- buildFileDescriptor(dependencyProto, fileDescriptorProtoMap)
- }
- case None =>
- throw
QueryCompilationErrors.protobufDescriptorDependencyError(dependency)
- }
- }
- Descriptors.FileDescriptor.buildFrom(fileDescriptorProto,
fileDescriptorList.toArray)
+ fileDescriptorProto: FileDescriptorProto,
+ fileDescriptorProtoMap: Map[String, FileDescriptorProto],
+ builtDescriptors: mutable.Map[String, Descriptors.FileDescriptor])
+ : Descriptors.FileDescriptor = {
+ // Storing references to constructed descriptors is crucial because
descriptors are compared
+ // by reference inside the Protobuf library.
+ builtDescriptors.getOrElseUpdate(
+ fileDescriptorProto.getName, {
+ val fileDescriptorList =
fileDescriptorProto.getDependencyList().asScala.map {
+ dependency =>
+ fileDescriptorProtoMap.get(dependency) match {
+ case Some(dependencyProto) =>
+ if (dependencyProto.getName == "google/protobuf/any.proto"
+ && dependencyProto.getPackage == "google.protobuf") {
+ // For Any, use the descriptor already included as part of
the Java dependency.
+ // Without this, JsonFormat used for converting Any fields
fails when
+ // an Any field in input is set to
`Any.getDefaultInstance()`.
+ com.google.protobuf.AnyProto.getDescriptor
+ // Should we do the same for timestamp.proto and empty.proto?
+ } else {
+ buildFileDescriptor(dependencyProto, fileDescriptorProtoMap,
builtDescriptors)
+ }
+ case None =>
+ throw
QueryCompilationErrors.protobufDescriptorDependencyError(dependency)
+ }
+ }
+ Descriptors.FileDescriptor.buildFrom(fileDescriptorProto,
fileDescriptorList.toArray)
+ })
}
/**
@@ -296,9 +346,60 @@ private[sql] object ProtobufUtils extends Logging {
}
/** Builds [[TypeRegistry]] with the descriptor and the others from the same
proto file. */
- private [protobuf] def buildTypeRegistry(descriptor: Descriptor):
TypeRegistry = {
- TypeRegistry.newBuilder()
+ private[protobuf] def buildTypeRegistry(descriptor: Descriptor):
TypeRegistry = {
+ TypeRegistry
+ .newBuilder()
.add(descriptor) // This adds any other descriptors in the associated
proto file.
.build()
}
+
+ /**
+ * Builds an ExtensionRegistry and an index from full name to field
descriptor for all extensions
+ * found in the list of provided file descriptors.
+ *
+ * This method will traverse the AST to ensure extensions in nested scopes
are registered as well.
+ *
+ * @param fileDescriptors
+ * List of all file descriptors to process
+ * @return
+ * The populated ExtensionRegistry and a map from message full names to
extension fields
+ * sorted by field number
+ */
+ private def buildExtensionRegistry(fileDescriptors:
List[Descriptors.FileDescriptor])
+ : (ExtensionRegistry, Map[String, Seq[FieldDescriptor]]) = {
+ val registry = ExtensionRegistry.newInstance()
+ val fullNameToExtensions = mutable.Map[String,
ArrayBuffer[FieldDescriptor]]()
+
+ // Adds an extension to both the registry and map.
+ def addExtension(extField: FieldDescriptor): Unit = {
+ val extendeeName = extField.getContainingType.getFullName
+ // For message-type extensions, we need to provide a default instance.
+ if (extField.getJavaType == JavaType.MESSAGE) {
+ val defaultInstance =
DynamicMessage.getDefaultInstance(extField.getMessageType)
+ registry.add(extField, defaultInstance)
+ } else {
+ registry.add(extField)
+ }
+ fullNameToExtensions
+ .getOrElseUpdate(extendeeName, mutable.ArrayBuffer())
+ .append(extField)
+ }
+
+ for (fileDesc <- fileDescriptors) {
+ fileDesc.getExtensions.asScala.foreach(addExtension)
+
+ // Recursively add nested extensions.
+ def collectNestedExtensions(msgDesc: Descriptor): Unit = {
+ msgDesc.getExtensions.asScala.foreach(addExtension)
+ msgDesc.getNestedTypes.asScala.foreach(collectNestedExtensions)
+ }
+ fileDesc.getMessageTypes.asScala.foreach(collectNestedExtensions)
+ }
+
+ // Sort extension fields by field number for consistent ordering.
+ val sortedMap = fullNameToExtensions.map { case (name, extensions) =>
+ name -> extensions.sortBy(_.getNumber).toSeq
+ }.toMap
+ (registry, sortedMap)
+ }
}
diff --git
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index 3273f473ed5c..5a73d9aaa72b 100644
---
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -38,34 +38,43 @@ object SchemaConverters extends Logging {
/**
* Converts an Protobuf schema to a corresponding Spark SQL schema.
*
+ * fullNamesToExtensions is a map from full message names to the field
descriptors of their
+ * extensions. It is non-empty when proto2 extension support is enabled.
+ *
* @since 3.4.0
*/
private[protobuf] def toSqlType(
descriptor: Descriptor,
+ fullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = Map.empty,
protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)):
SchemaType = {
- toSqlTypeHelper(descriptor, protobufOptions)
- }
+ val existingRecordNames = Map(descriptor.getFullName -> 1)
+ val regularFields = descriptor.getFields.asScala
+ .flatMap(structFieldFor(_, existingRecordNames, fullNamesToExtensions,
protobufOptions))
+ .toSeq
+ val extensionFields =
fullNamesToExtensions.getOrElse(descriptor.getFullName, Seq.empty)
+ val extStructFields = extensionFields.flatMap(
+ structFieldFor(_, existingRecordNames, fullNamesToExtensions,
protobufOptions))
- private[protobuf] def toSqlTypeHelper(
- descriptor: Descriptor,
- protobufOptions: ProtobufOptions): SchemaType = {
- val fields = descriptor.getFields.asScala.flatMap(
- structFieldFor(_,
- Map(descriptor.getFullName -> 1),
- protobufOptions: ProtobufOptions)).toSeq
- if (fields.isEmpty && protobufOptions.retainEmptyMessage) {
+ val allFields = regularFields ++ extStructFields
+
+ if (allFields.isEmpty && protobufOptions.retainEmptyMessage) {
SchemaType(convertEmptyProtoToStructWithDummyField(descriptor.getFullName),
nullable = true)
- } else SchemaType(StructType(fields), nullable = true)
+ } else {
+ SchemaType(StructType(allFields), nullable = true)
+ }
}
// existingRecordNames: Map[String, Int] used to track the depth of
recursive fields and to
// ensure that the conversion of the protobuf message to a Spark SQL
StructType object does not
// exceed the maximum recursive depth specified by the
recursiveFieldMaxDepth option.
+ // fullNamesToExtensions: a map from Protobuf message full names to the
field descriptors of
+ // their extensions used to add extension fields to the generated
StructField.
// A return of None implies the field has reached the maximum allowed
recursive depth and
// should be dropped.
private def structFieldFor(
fd: FieldDescriptor,
existingRecordNames: Map[String, Int],
+ fullNamesToExtensions: Map[String, Seq[FieldDescriptor]],
protobufOptions: ProtobufOptions): Option[StructField] = {
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
@@ -150,16 +159,12 @@ object SchemaConverters extends Logging {
field.getName match {
case "key" =>
keyType =
- structFieldFor(
- field,
- existingRecordNames,
- protobufOptions).map(_.dataType)
+ structFieldFor(field, existingRecordNames,
fullNamesToExtensions, protobufOptions)
+ .map(_.dataType)
case "value" =>
valueType =
- structFieldFor(
- field,
- existingRecordNames,
- protobufOptions).map(_.dataType)
+ structFieldFor(field, existingRecordNames,
fullNamesToExtensions, protobufOptions)
+ .map(_.dataType)
}
}
(keyType, valueType) match {
@@ -206,10 +211,16 @@ object SchemaConverters extends Logging {
None
} else {
val newRecordNames = existingRecordNames + (recordName ->
(recursiveDepth + 1))
- val fields = fd.getMessageType.getFields.asScala.flatMap(
- structFieldFor(_, newRecordNames, protobufOptions)
- ).toSeq
- fields match {
+ // Get regular fields for the nested message
+ val regularFields = fd.getMessageType.getFields.asScala
+ .flatMap(structFieldFor(_, newRecordNames, fullNamesToExtensions,
protobufOptions))
+ .toSeq
+ // Get extension fields for the nested message
+ val nestedExtensions = fullNamesToExtensions.getOrElse(recordName,
Seq.empty)
+ val extFields = nestedExtensions.flatMap(
+ structFieldFor(_, newRecordNames, fullNamesToExtensions,
protobufOptions))
+ val allFields = regularFields ++ extFields
+ allFields match {
case Nil =>
if (protobufOptions.retainEmptyMessage) {
Some(convertEmptyProtoToStructWithDummyField(fd.getFullName))
diff --git
a/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
b/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
index 75525715a3d5..2d909bd4d8bd 100644
--- a/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
+++ b/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
@@ -55,3 +55,72 @@ message Proto2AllTypes {
}
map<string, string> map = 13;
}
+
+// Message with extension range for testing proto2 extension support
+message Proto2ExtensionTest {
+ optional string name = 1;
+ optional int32 id = 2;
+
+ extensions 100 to 200;
+}
+
+// Identical to Proto2ExtensionTest but without any extensions.
+// Used for schema evolution testing.
+message Proto2ExtensionTestBase {
+ optional string name = 1;
+ optional int32 id = 2;
+}
+
+// File-level extensions for Proto2ExtensionTest
+extend Proto2ExtensionTest {
+ optional string extension_string = 100;
+ optional int32 extension_int = 101;
+ optional NestedExtensionMessage extension_message = 102;
+ repeated int32 extension_repeated_int = 103;
+ optional Proto2AllTypes.NestedEnum extension_enum = 104;
+}
+
+// Nested message type used in an extension field
+message NestedExtensionMessage {
+ optional string nested_value = 1;
+ optional int32 nested_id = 2;
+}
+
+message ContainerWithNestedExtension {
+ optional string container_name = 1;
+
+ // Extension within the scope of another message.
+ extend Proto2ExtensionTest {
+ optional bool nested_extension_bool = 150;
+ }
+}
+
+// Test nested message extensions: a top-level message containing a nested
+// message type that itself has extensions.
+message MessageWithExtendableNested {
+ optional string top_level_name = 1;
+ optional ExtendableNestedMessage nested = 2;
+}
+
+message ExtendableNestedMessage {
+ optional string nested_name = 1;
+ optional int32 nested_id = 2;
+
+ extensions 100 to 200;
+}
+
+extend ExtendableNestedMessage {
+ optional string nested_ext_field = 100;
+ optional int32 nested_ext_int = 101;
+}
+
+// Message for testing extension field name collision
+message Proto2ExtensionCollisionTest {
+ optional string name = 1;
+ extensions 100 to 200;
+}
+
+// Extension that collides with "name" field (case-insensitive)
+extend Proto2ExtensionCollisionTest {
+ optional string NAME = 100;
+}
diff --git
a/connector/protobuf/src/test/resources/protobuf/proto2_messages_ext.proto
b/connector/protobuf/src/test/resources/protobuf/proto2_messages_ext.proto
new file mode 100644
index 000000000000..40ae8008d544
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/proto2_messages_ext.proto
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf.protos;
+
+import "proto2_messages.proto";
+
+option java_outer_classname = "Proto2MessagesExt";
+
+message Proto2ExtensionCrossFile {
+ optional int32 foo = 1;
+
+ extend Proto2ExtensionTest {
+ optional Proto2ExtensionCrossFile cross_file_extension = 200;
+ }
+}
diff --git
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index 1802bceee1df..5f89138d7248 100644
---
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -152,7 +152,8 @@ class ProtobufCatalystDataConversionSuite
expected: Option[Any],
filters: StructFilters = new NoopFilters): Unit = {
- val descriptor = ProtobufUtils.buildDescriptor(descFileBytes, messageName)
+ val descriptor =
+ ProtobufUtils.buildDescriptor(messageName,
Some(descFileBytes)).descriptor
val dataType = SchemaConverters.toSqlType(descriptor).dataType
val deserializer = new ProtobufDeserializer(descriptor, dataType, filters)
@@ -163,6 +164,7 @@ class ProtobufCatalystDataConversionSuite
// Verify Java class deserializer matches with descriptor based serializer.
val javaDescriptor = ProtobufUtils
.buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName")
+ .descriptor
assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType)
val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType,
filters)
.deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray))
@@ -199,7 +201,8 @@ class ProtobufCatalystDataConversionSuite
.add("name", "string")
.add("age", "int")
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person")
+ val descriptor =
+ ProtobufUtils.buildDescriptor("Person", Some(testFileDesc)).descriptor
val dynamicMessage = DynamicMessage
.newBuilder(descriptor)
.setField(descriptor.findFieldByName("name"), "Maxim")
@@ -237,11 +240,13 @@ class ProtobufCatalystDataConversionSuite
}
test("Full names for message using descriptor file") {
- val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
+ val withShortName =
+ ProtobufUtils.buildDescriptor("BytesMsg", Some(testFileDesc)).descriptor
assert(withShortName.findFieldByName("bytes_type") != null)
- val withFullName = ProtobufUtils.buildDescriptor(
- testFileDesc, "org.apache.spark.sql.protobuf.protos.BytesMsg")
+ val withFullName = ProtobufUtils
+ .buildDescriptor("org.apache.spark.sql.protobuf.protos.BytesMsg",
Some(testFileDesc))
+ .descriptor
assert(withFullName.findFieldByName("bytes_type") != null)
}
diff --git
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufExtensionsSuite.scala
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufExtensionsSuite.scala
new file mode 100644
index 000000000000..10c86c556cc5
--- /dev/null
+++
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufExtensionsSuite.scala
@@ -0,0 +1,538 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import com.google.protobuf.DescriptorProtos.FileDescriptorSet
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import
org.apache.spark.sql.internal.SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
+
+class ProtobufExtensionsSuite
+ extends QueryTest
+ with SharedSparkSession
+ with ProtobufTestBase
+ with Serializable {
+
+ import testImplicits._
+
+ val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
+ val proto2FileDesc =
CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
+ private val proto2JavaClassNamePrefix =
"org.apache.spark.sql.protobuf.protos.Proto2Messages$"
+
+ test("SPARK-55062: roundtrip - proto2 extension basic types") {
+ withExtensionSupport {
+ val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ val descriptor = descWithExt.descriptor
+ val extensionFields =
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+ val extStringField = extensionFields.find(_.getName ==
"extension_string").get
+ val extIntField = extensionFields.find(_.getName == "extension_int").get
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "test_name")
+ .setField(descriptor.findFieldByName("id"), 42)
+ .setField(extStringField, "ext_value")
+ .setField(extIntField, 123)
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionTest",
proto2FileDesc).as("value_from"))
+
+ // Verify extension field values are correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("name") == "test_name")
+ assert(row.getAs[Int]("id") == 42)
+ assert(row.getAs[String]("extension_string") == "ext_value")
+ assert(row.getAs[Int]("extension_int") == 123)
+
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: roundtrip - proto2 extension enum") {
+ withExtensionSupport {
+ val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ val descriptor = descWithExt.descriptor
+ val extensionFields =
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+ val extEnumField = extensionFields.find(_.getName ==
"extension_enum").get
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "enum_test")
+ .setField(descriptor.findFieldByName("id"), 99)
+ .setField(extEnumField,
extEnumField.getEnumType.findValueByName("FIRST"))
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionTest",
proto2FileDesc).as("value_from"))
+
+ // Verify extension field value is correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("name") == "enum_test")
+ assert(row.getAs[Int]("id") == 99)
+ assert(row.getAs[String]("extension_enum") == "FIRST")
+
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: roundtrip - proto2 extension nested message") {
+ withExtensionSupport {
+ val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ val descriptor = descWithExt.descriptor
+ val extensionFields = descWithExt.fullNamesToExtensions
+ .getOrElse(descriptor.getFullName, Seq.empty)
+
+ val extMessageField = extensionFields.find(_.getName ==
"extension_message").get
+
+ val nestedMessage = DynamicMessage
+ .newBuilder(extMessageField.getMessageType)
+
.setField(extMessageField.getMessageType.findFieldByName("nested_value"),
"nested_test")
+ .setField(extMessageField.getMessageType.findFieldByName("nested_id"),
99)
+ .build()
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "main")
+ .setField(extMessageField, nestedMessage)
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionTest",
proto2FileDesc).as("value_from"))
+
+ // Verify extension field value is correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("name") == "main")
+ val nestedRow = row.getAs[Row]("extension_message")
+ assert(nestedRow.getAs[String]("nested_value") == "nested_test")
+ assert(nestedRow.getAs[Int]("nested_id") == 99)
+
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: roundtrip - proto2 extension repeated") {
+ withExtensionSupport {
+ val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ val descriptor = descWithExt.descriptor
+ val extensionFields = descWithExt.fullNamesToExtensions
+ .getOrElse(descriptor.getFullName, Seq.empty)
+
+ val extRepeatedField = extensionFields.find(_.getName ==
"extension_repeated_int").get
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "repeated_test")
+ .addRepeatedField(extRepeatedField, 1)
+ .addRepeatedField(extRepeatedField, 2)
+ .addRepeatedField(extRepeatedField, 3)
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionTest",
proto2FileDesc).as("value_from"))
+
+ // Verify extension field value is correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("name") == "repeated_test")
+ val repeatedValues =
row.getSeq[Int](row.fieldIndex("extension_repeated_int"))
+ assert(repeatedValues == Seq(1, 2, 3))
+
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: roundtrip - proto2 extension defined in another message")
{
+ withExtensionSupport {
+ val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ val descriptor = descWithExt.descriptor
+ val extensionFields = descWithExt.fullNamesToExtensions
+ .getOrElse(descriptor.getFullName, Seq.empty)
+
+ val nestedExtBoolField = extensionFields.find(_.getName ==
"nested_extension_bool")
+ assert(
+ nestedExtBoolField.isDefined,
+ "Should find extension defined in message
ContainerWithNestedExtension")
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "nested_ext_test")
+ .setField(nestedExtBoolField.get, true)
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionTest",
proto2FileDesc).as("value_from"))
+
+ // Verify extension field value is correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("name") == "nested_ext_test")
+ assert(row.getAs[Boolean]("nested_extension_bool"))
+
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: roundtrip - proto2 extension on nested message") {
+ withExtensionSupport {
+ val descWithExt =
+ ProtobufUtils.buildDescriptor("MessageWithExtendableNested",
Some(proto2FileDesc))
+ val topLevelDescriptor = descWithExt.descriptor
+
+ val nestedFieldDesc = topLevelDescriptor.findFieldByName("nested")
+ val nestedMsgDescriptor = nestedFieldDesc.getMessageType
+ val nestedExtensions = descWithExt.fullNamesToExtensions
+ .getOrElse(nestedMsgDescriptor.getFullName, Seq.empty)
+
+ assert(
+ nestedExtensions.nonEmpty,
+ "Should find extensions for nested message type
ExtendableNestedMessage")
+ val nestedExtField = nestedExtensions.find(_.getName ==
"nested_ext_field").get
+ val nestedExtInt = nestedExtensions.find(_.getName ==
"nested_ext_int").get
+
+ val nestedMessage = DynamicMessage
+ .newBuilder(nestedMsgDescriptor)
+ .setField(nestedMsgDescriptor.findFieldByName("nested_name"),
"nested_name_value")
+ .setField(nestedMsgDescriptor.findFieldByName("nested_id"), 42)
+ .setField(nestedExtField, "ext_field_value")
+ .setField(nestedExtInt, 123)
+ .build()
+
+ val topLevelMessage = DynamicMessage
+ .newBuilder(topLevelDescriptor)
+ .setField(topLevelDescriptor.findFieldByName("top_level_name"),
"top_name")
+ .setField(nestedFieldDesc, nestedMessage)
+ .build()
+
+ val df = Seq(topLevelMessage.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions
+ .from_protobuf($"value", "MessageWithExtendableNested",
proto2FileDesc)
+ .as("value_from"))
+
+ // Verify nested extension field values are correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("top_level_name") == "top_name")
+ val nestedRow = row.getAs[Row]("nested")
+ assert(nestedRow.getAs[String]("nested_name") == "nested_name_value")
+ assert(nestedRow.getAs[Int]("nested_id") == 42)
+ assert(nestedRow.getAs[String]("nested_ext_field") == "ext_field_value")
+ assert(nestedRow.getAs[Int]("nested_ext_int") == 123)
+
+ // Verify roundtrip preserves values
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "MessageWithExtendableNested",
proto2FileDesc)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "MessageWithExtendableNested",
proto2FileDesc)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: roundtrip - proto2 extension in separate file") {
+ withExtensionSupport {
+ // NB: We manually construct a merged descriptor file because Maven's
Protobuf plugin
+ // generates a single descriptor file for each .proto file.
+ val extDescriptorPath =
protobufDescriptorFile("proto2_messages_ext.desc")
+ val fdsExt = FileDescriptorSet.parseFrom(
+ CommonProtobufUtils.readDescriptorFileContent(extDescriptorPath))
+ val fds = FileDescriptorSet.parseFrom(proto2FileDesc)
+ val combinedFds = FileDescriptorSet.newBuilder
+ .addAllFile(fdsExt.getFileList)
+ .addAllFile(fds.getFileList)
+ .build()
+ val combinedFdsBytes = combinedFds.toByteArray
+
+ val descWithExt =
+ ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(combinedFds.toByteArray))
+ val descriptor = descWithExt.descriptor
+ val extensionFields = descWithExt.fullNamesToExtensions
+ .getOrElse(descriptor.getFullName, Seq.empty)
+
+ val extMessageField = extensionFields.find(_.getName ==
"cross_file_extension").get
+
+ val nestedMessage = DynamicMessage
+ .newBuilder(extMessageField.getMessageType)
+ .setField(extMessageField.getMessageType.findFieldByName("foo"), 1)
+ .build()
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "main")
+ .setField(extMessageField, nestedMessage)
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions
+ .from_protobuf($"value", "Proto2ExtensionTest", combinedFdsBytes)
+ .as("value_from"))
+
+ // Verify extension field value is correct
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ assert(row.getAs[String]("name") == "main")
+ val crossFileExtRow = row.getAs[Row]("cross_file_extension")
+ assert(crossFileExtRow.getAs[Int]("foo") == 1)
+
+ val toProtoDF = fromProtoDF.select(
+ functions
+ .to_protobuf($"value_from", "Proto2ExtensionTest", combinedFdsBytes)
+ .as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", "Proto2ExtensionTest", combinedFdsBytes)
+ .as("value_to_from"))
+
+ checkAnswer(fromProtoDF.select($"value_from.*"),
toFromProtoDF.select($"value_to_from.*"))
+ }
+ }
+
+ test("SPARK-55062: proto2 extension field name collision with regular
field") {
+ withExtensionSupport {
+ val descriptor =
+ ProtobufUtils
+ .buildDescriptor("Proto2ExtensionCollisionTest",
Some(proto2FileDesc))
+ .descriptor
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "test")
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ // Some unwrapping required to get to the root error, as it is thrown
during execution.
+ val e = intercept[SparkException] {
+ df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionCollisionTest",
proto2FileDesc))
+ .collect()
+ }
+ checkError(
+ exception = e.getCause
+ .asInstanceOf[AnalysisException]
+ .getCause
+ .asInstanceOf[AnalysisException],
+ condition = "PROTOBUF_FIELD_MISSING",
+ parameters = Map(
+ "field" -> "name",
+ "protobufSchema" -> "top-level record",
+ "matchSize" -> "2",
+ "matches" -> "[name, NAME]"))
+ }
+ }
+
+ test("SPARK-55062: schema evolution - data without extensions read with
extension schema") {
+ withExtensionSupport {
+ val descriptorWithoutExt =
+ ProtobufUtils.buildDescriptor("Proto2ExtensionTestBase",
Some(proto2FileDesc)).descriptor
+ val messageWithoutExtensions = DynamicMessage
+ .newBuilder(descriptorWithoutExt)
+ .setField(descriptorWithoutExt.findFieldByName("name"), "base_name")
+ .setField(descriptorWithoutExt.findFieldByName("id"), 42)
+ .build()
+
+ val df = Seq(messageWithoutExtensions.toByteArray).toDF("value")
+
+ // Read using schema WITH extensions
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", "Proto2ExtensionTest",
proto2FileDesc).as("parsed"))
+ val result = fromProtoDF.select($"parsed.*").collect()
+ assert(result.length == 1)
+ val row = result(0)
+
+ // Regular fields have expected values, while extension fields are
null/empty
+ assert(row.getAs[String]("name") == "base_name")
+ assert(row.getAs[Int]("id") == 42)
+ assert(row.getAs[String]("extension_string") == null)
+ assert(row.isNullAt(row.fieldIndex("extension_int")))
+ assert(row.isNullAt(row.fieldIndex("extension_enum")))
+ assert(row.isNullAt(row.fieldIndex("extension_message")))
+ assert(row.isNullAt(row.fieldIndex("nested_extension_bool")))
+ assert(row.getSeq[Int](row.fieldIndex("extension_repeated_int")).isEmpty)
+ }
+ }
+
+ test("SPARK-55062: Java class fallback drops extensions") {
+ withExtensionSupport {
+ val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ val descriptor = descWithExt.descriptor
+ val extensionFields =
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+ val extStringField = extensionFields.find(_.getName ==
"extension_string").get
+
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "test_name")
+ .setField(descriptor.findFieldByName("id"), 42)
+ .setField(extStringField, "ext_value")
+ .build()
+ val df = Seq(message.toByteArray).toDF("value")
+
+ // Read using Java class, which does not support extensions
+ val fromProtoDF = df.select(
+ functions
+ .from_protobuf($"value", proto2JavaClassNamePrefix +
"Proto2ExtensionTest")
+ .as("parsed"))
+
+ // Schema should only have regular fields, not extension fields
+ val row = fromProtoDF.select($"parsed.*").collect().head
+ val schema = row.schema
+ assert(row.getAs[String]("name") == "test_name")
+ assert(row.getAs[Int]("id") == 42)
+ assert(!schema.fieldNames.contains("extension_string"))
+ assert(!schema.fieldNames.contains("extension_int"))
+ }
+ }
+
+ test("SPARK-55062: from_protobuf drops extension fields by default") {
+ // We need extension support enabled when calling buildDescriptor so that
we can access the
+ // extension fields for test setup and assertions.
+ val descWithExt = withExtensionSupport {
+ ProtobufUtils.buildDescriptor("Proto2ExtensionTest",
Some(proto2FileDesc))
+ }
+ val descriptor = descWithExt.descriptor
+ val extensionFields =
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+ val extStringField = extensionFields.find(_.getName ==
"extension_string").get
+ val message = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "test_name")
+ .setField(descriptor.findFieldByName("id"), 42)
+ .setField(extStringField, "ext_value")
+ .build()
+
+ val df = Seq(message.toByteArray).toDF("value")
+
+ val fromProtoDF = df.select(
+ functions
+ .from_protobuf($"value", "Proto2ExtensionTest", proto2FileDesc)
+ .as("value_from"))
+ val row = fromProtoDF.select($"value_from.*").collect().head
+ val schema = row.schema
+ assert(row.getAs[String]("name") == "test_name")
+ assert(row.getAs[Int]("id") == 42)
+ assert(!schema.fieldNames.contains("extension_string"))
+ assert(!schema.fieldNames.contains("extension_int"))
+ }
+
+ test("SPARK-55062: to_protobuf does not recognize extension fields by
default") {
+ val schema = StructType(
+ StructField(
+ "Proto2ExtensionTest",
+ StructType(
+ StructField("name", StringType) ::
+ StructField("id", IntegerType) ::
+ StructField("extension_string", StringType) ::
+ StructField("extension_int", IntegerType) :: Nil)) :: Nil)
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(Row(Row("test_name", 42,
"extension_value", 123)))),
+ schema)
+
+ // We expect an error because, from to_protobuf's perspective, the
DataFrame contains extra
+ // fields that are not in the message descriptor.
+ val e = intercept[AnalysisException] {
+ val toProtoDF = df
+ .select(
+ functions
+ .to_protobuf($"Proto2ExtensionTest", "Proto2ExtensionTest",
proto2FileDesc)
+ .as("to_proto"))
+ .collect()
+ }
+ val toType = "\"STRUCT<name: STRING, id: INT, extension_string: STRING,
extension_int: INT>\""
+ checkError(
+ exception = e,
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ parameters = Map("protobufType" -> "Proto2ExtensionTest", "toType" ->
toType))
+ }
+
+ // Extension support is disabled by default.
+ private def withExtensionSupport[T](f: => T): T = {
+ val old = conf.getConf(PROTOBUF_EXTENSIONS_SUPPORT_ENABLED)
+ conf.setConf(PROTOBUF_EXTENSIONS_SUPPORT_ENABLED, true)
+ try {
+ f
+ } finally {
+ conf.setConf(PROTOBUF_EXTENSIONS_SUPPORT_ENABLED, old)
+ }
+ }
+}
diff --git
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index e46fcb1a1735..b283e4fb8fa9 100644
---
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -163,8 +163,12 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") {
- val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"RepeatedMessage")
- val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"BasicMessage")
+ val repeatedMessageDesc = ProtobufUtils
+ .buildDescriptor("RepeatedMessage", Some(testFileDesc))
+ .descriptor
+ val basicMessageDesc = ProtobufUtils
+ .buildDescriptor("BasicMessage", Some(testFileDesc))
+ .descriptor
val basicMessage = DynamicMessage
.newBuilder(basicMessageDesc)
@@ -200,8 +204,12 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") {
- val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"RepeatedMessage")
- val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"BasicMessage")
+ val repeatedMessageDesc = ProtobufUtils
+ .buildDescriptor("RepeatedMessage", Some(testFileDesc))
+ .descriptor
+ val basicMessageDesc = ProtobufUtils
+ .buildDescriptor("BasicMessage", Some(testFileDesc))
+ .descriptor
val basicMessage1 = DynamicMessage
.newBuilder(basicMessageDesc)
@@ -251,7 +259,9 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("roundtrip in from_protobuf and to_protobuf - Map") {
- val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"SimpleMessageMap")
+ val messageMapDesc = ProtobufUtils
+ .buildDescriptor("SimpleMessageMap", Some(testFileDesc))
+ .descriptor
val mapStr1 = DynamicMessage
.newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
@@ -345,8 +355,12 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("roundtrip in from_protobuf and to_protobuf - Enum") {
- val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"SimpleMessageEnum")
- val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"BasicEnumMessage")
+ val messageEnumDesc = ProtobufUtils
+ .buildDescriptor("SimpleMessageEnum", Some(testFileDesc))
+ .descriptor
+ val basicEnumDesc = ProtobufUtils
+ .buildDescriptor("BasicEnumMessage", Some(testFileDesc))
+ .descriptor
val dynamicMessage = DynamicMessage
.newBuilder(messageEnumDesc)
@@ -389,9 +403,15 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("round trip in from_protobuf and to_protobuf - Multiple Message") {
- val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"MultipleExample")
- val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"IncludedExample")
- val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"OtherExample")
+ val messageMultiDesc = ProtobufUtils
+ .buildDescriptor("MultipleExample", Some(testFileDesc))
+ .descriptor
+ val messageIncludeDesc = ProtobufUtils
+ .buildDescriptor("IncludedExample", Some(testFileDesc))
+ .descriptor
+ val messageOtherDesc = ProtobufUtils
+ .buildDescriptor("OtherExample", Some(testFileDesc))
+ .descriptor
val otherMessage = DynamicMessage
.newBuilder(messageOtherDesc)
@@ -470,8 +490,8 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
val descBytes =
CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
- val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
- val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
+ val oldProducer = ProtobufUtils.buildDescriptor("oldProducer",
Some(descBytes)).descriptor
+ val newConsumer = ProtobufUtils.buildDescriptor("newConsumer",
Some(descBytes)).descriptor
val oldProducerMessage = DynamicMessage
.newBuilder(oldProducer)
@@ -512,8 +532,8 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
val descBytes =
CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
- val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
- val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")
+ val newProducer = ProtobufUtils.buildDescriptor("newProducer",
Some(descBytes)).descriptor
+ val oldConsumer = ProtobufUtils.buildDescriptor("oldConsumer",
Some(descBytes)).descriptor
val newProducerMessage = DynamicMessage
.newBuilder(newProducer)
@@ -560,7 +580,9 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
val binary = toProtobuf.first().get(0).asInstanceOf[Array[Byte]]
- val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc,
"requiredMsg")
+ val messageDescriptor = ProtobufUtils
+ .buildDescriptor("requiredMsg", Some(testFileDesc))
+ .descriptor
val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary)
assert(actualMessage.getField(messageDescriptor.findFieldByName("key"))
@@ -582,7 +604,9 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("from_protobuf filter to_protobuf") {
- val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc,
"BasicMessage")
+ val basicMessageDesc = ProtobufUtils
+ .buildDescriptor("BasicMessage", Some(testFileDesc))
+ .descriptor
val basicMessage = DynamicMessage
.newBuilder(basicMessageDesc)
@@ -714,7 +738,7 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("Verify OneOf field between from_protobuf -> to_protobuf and struct ->
from_protobuf") {
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent")
+ val descriptor = ProtobufUtils.buildDescriptor("OneOfEvent",
Some(testFileDesc)).descriptor
val oneOfEvent = OneOfEvent.newBuilder()
.setKey("key")
.setCol1(123)
@@ -804,7 +828,7 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
}
test("Verify recursion field with complex schema with
recursive.fields.max.depth") {
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Employee")
+ val descriptor = ProtobufUtils.buildDescriptor("Employee",
Some(testFileDesc)).descriptor
val manager =
Employee.newBuilder().setFirstName("firstName").setLastName("lastName").build()
val em2 = EM2.newBuilder().setTeamsize(100).setEm2Manager(manager).build()
@@ -846,7 +870,9 @@ class ProtobufFunctionsSuite extends QueryTest with
SharedSparkSession with Prot
test("Verify OneOf field with recursive fields between from_protobuf ->
to_protobuf." +
"and struct -> from_protobuf") {
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc,
"OneOfEventWithRecursion")
+ val descriptor = ProtobufUtils
+ .buildDescriptor("OneOfEventWithRecursion", Some(testFileDesc))
+ .descriptor
val nestedTwo = OneOfEventWithRecursion.newBuilder()
.setKey("keyNested2").setValue("valueNested2").build()
diff --git
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index f3bd49e1b24a..71bcb6b51a56 100644
---
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -47,7 +47,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
test("Test basic conversion") {
withFieldMatchType { fieldMatch =>
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc,
"SerdeBasicMessage")
+ val protoFile =
+ ProtobufUtils.buildDescriptor("SerdeBasicMessage",
Some(testFileDesc)).descriptor
val dynamicMessageFoo = DynamicMessage
.newBuilder(protoFile.getFile.findMessageTypeByName("Foo"))
@@ -71,7 +72,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
// This test verifies that optional fields can be missing from input
Catalyst schema
// while serializing rows to protobuf.
- val desc = ProtobufUtils.buildDescriptor(proto2Desc,
"FoobarWithRequiredFieldBar")
+ val desc =
+ ProtobufUtils.buildDescriptor("FoobarWithRequiredFieldBar",
Some(proto2Desc)).descriptor
// Confirm desc contains optional field 'foo' and required field bar.
assert(desc.getFields.size() == 2)
@@ -90,7 +92,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
}
test("Fail to convert with field type mismatch") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc,
"MissMatchTypeInRoot")
+ val protoFile =
+ ProtobufUtils.buildDescriptor("MissMatchTypeInRoot",
Some(testFileDesc)).descriptor
withFieldMatchType { fieldMatch =>
assertFailedConversionMessage(
protoFile,
@@ -113,7 +116,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
}
test("Fail to convert with missing nested Protobuf fields for serializer") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc,
"FieldMissingInProto")
+ val protoFile =
+ ProtobufUtils.buildDescriptor("FieldMissingInProto",
Some(testFileDesc)).descriptor
val nonnullCatalyst = new StructType()
.add("foo", new StructType().add("bar", IntegerType, nullable = false))
@@ -142,7 +146,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
test("Fail to convert with deeply nested field type mismatch") {
val protoFile = ProtobufUtils.buildDescriptorFromJavaClass(
s"${javaClassNamePrefix}MissMatchTypeInDeepNested"
- )
+ ).descriptor
val catalyst = new StructType().add("top", CATALYST_STRUCT)
withFieldMatchType { fieldMatch =>
@@ -169,12 +173,13 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
}
test("Fail to convert with missing Catalyst fields") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc,
"FieldMissingInProto")
+ val protoFile =
+ ProtobufUtils.buildDescriptor("FieldMissingInProto",
Some(testFileDesc)).descriptor
val foobarSQLType = structFromDDL("struct<foo string>") // "bar" is
missing.
assertFailedConversionMessage(
- ProtobufUtils.buildDescriptor(proto2Desc, "FoobarWithRequiredFieldBar"),
+ ProtobufUtils.buildDescriptor("FoobarWithRequiredFieldBar",
Some(proto2Desc)).descriptor,
Serializer,
BY_NAME,
catalystSchema = foobarSQLType,
@@ -189,14 +194,17 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
val protoNestedFile = ProtobufUtils
- .buildDescriptor(proto2Desc, "NestedFoobarWithRequiredFieldBar")
+ .buildDescriptor("NestedFoobarWithRequiredFieldBar", Some(proto2Desc))
+ .descriptor
val nestedFoobarSQLType = structFromDDL(
"struct<nested_foobar: struct<foo string>>" // "bar" field is missing.
)
// serializing with extra fails if required field is missing in inner
struct
assertFailedConversionMessage(
- ProtobufUtils.buildDescriptor(proto2Desc,
"NestedFoobarWithRequiredFieldBar"),
+ ProtobufUtils
+ .buildDescriptor("NestedFoobarWithRequiredFieldBar", Some(proto2Desc))
+ .descriptor,
Serializer,
BY_NAME,
catalystSchema = nestedFoobarSQLType,
@@ -216,9 +224,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
val e1 = intercept[AnalysisException] {
ProtobufUtils.buildDescriptor(
- CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
- "SerdeBasicMessage"
- )
+ "SerdeBasicMessage",
+ Some(CommonProtobufUtils.readDescriptorFileContent(fileDescFile)))
}
checkError(
@@ -234,9 +241,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with
ProtobufTestBase {
val e2 = intercept[AnalysisException] {
- ProtobufUtils.buildDescriptor(
- basicMessageDescWithoutImports,
- "BasicMessage")
+ ProtobufUtils.buildDescriptor("BasicMessage",
Some(basicMessageDescWithoutImports))
}
checkError(
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 495ece401949..b2a2b7027394 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -6896,6 +6896,15 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)
+ val PROTOBUF_EXTENSIONS_SUPPORT_ENABLED =
+ buildConf("spark.sql.function.protobufExtensions.enabled")
+      .doc("When true, the from_protobuf and to_protobuf operators will support proto2 " +
+        "extensions when a binary file descriptor set is provided. This property will have no " +
+        "effect for the overloads taking a Java class name instead of a file descriptor set.")
+ .version("4.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
/**
* Holds information about keys that have been deprecated.
*
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index 23c266e82891..dff09c736a83 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -203,7 +203,7 @@ object OffsetSeqMetadata extends Logging {
STATE_STORE_ROCKSDB_FORMAT_VERSION,
STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION,
STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN,
STREAMING_STATE_STORE_ENCODING_FORMAT,
- STATE_STORE_ROW_CHECKSUM_ENABLED
+ STATE_STORE_ROW_CHECKSUM_ENABLED, PROTOBUF_EXTENSIONS_SUPPORT_ENABLED
)
/**
@@ -251,7 +251,8 @@ object OffsetSeqMetadata extends Logging {
PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true",
STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow",
STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false",
- STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1"
+ STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1",
+ PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key -> "false"
)
def readValue[T](metadataLog: OffsetSeqMetadataBase, confKey:
ConfigEntry[T]): String = {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProtobufExtensionsSupportOffsetLogSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProtobufExtensionsSupportOffsetLogSuite.scala
new file mode 100644
index 000000000000..0b9914e1dff0
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProtobufExtensionsSupportOffsetLogSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Tests for the spark.sql.function.protobufExtensions.enabled flag and its persistence in the
+ * offset log.
+ */
+class ProtobufExtensionsSupportOffsetLogSuite extends SharedSparkSession {
+
+ test("protobuf extensions support disabled by default for new queries") {
+ val offsetSeqMetadata =
+ OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, spark.conf)
+    assert(
+      offsetSeqMetadata.conf.get(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key) ===
+        Some(false.toString))
+ }
+
+ test("protobuf extensions support enabled when session sets it to true") {
+ val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+ withSQLConf(protobufExtConf -> true.toString) {
+      val offsetSeqMetadata =
+        OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, spark.conf)
+      assert(offsetSeqMetadata.conf.get(protobufExtConf) === Some(true.toString))
+ }
+ }
+
+ test(
+    "protobuf extensions support uses default false for old checkpoint when enabled in session") {
+ val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+ withSQLConf(protobufExtConf -> true.toString) {
+ val existingChkpt = "offset-log-version-2.1.0"
+ val (_, offsetSeq) = readFromResource(existingChkpt)
+ val offsetSeqMetadata = offsetSeq.metadataOpt.get
+ // Not present in existing checkpoint
+ assert(offsetSeqMetadata.conf.get(protobufExtConf) === None)
+
+ val clonedSqlConf = spark.sessionState.conf.clone()
+ OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+      assert(!clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+ }
+ }
+
+ test(
+    "protobuf extensions support uses default false for old checkpoint when unset in session") {
+ val existingChkpt = "offset-log-version-2.1.0"
+ val (_, offsetSeq) = readFromResource(existingChkpt)
+ val offsetSeqMetadata = offsetSeq.metadataOpt.get
+ val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+ // Not present in existing checkpoint
+ assert(offsetSeqMetadata.conf.get(protobufExtConf) === None)
+
+ val clonedSqlConf = spark.sessionState.conf.clone()
+ OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+ assert(!clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+ }
+
+  test("protobuf extensions support in existing checkpoint takes precedence over session value") {
+ val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+
+ // Enabled when checkpoint = enabled and session = disabled
+ withSQLConf(protobufExtConf -> false.toString) {
+ val offsetSeqMetadata = OffsetSeqMetadata(
+ batchWatermarkMs = 0,
+ batchTimestampMs = 0,
+ Map(protobufExtConf -> true.toString))
+
+ val clonedSqlConf = spark.sessionState.conf.clone()
+ OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+      assert(clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+ }
+
+ // Disabled when checkpoint = disabled and session = enabled
+ withSQLConf(protobufExtConf -> true.toString) {
+ val offsetSeqMetadata = OffsetSeqMetadata(
+ batchWatermarkMs = 0,
+ batchTimestampMs = 0,
+ Map(protobufExtConf -> false.toString))
+
+ val clonedSqlConf = spark.sessionState.conf.clone()
+ OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+      assert(!clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+ }
+ }
+
+ private def readFromResource(dir: String): (Long, OffsetSeqBase) = {
+ val input = getClass.getResource(s"/structured-streaming/$dir")
+ val log = new OffsetSeqLog(spark, input.toString)
+ log.getLatest().get
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]