This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dd5ce947d808 [SPARK-55062][PROTOBUF] Support proto2 extensions in 
protobuf functions
dd5ce947d808 is described below

commit dd5ce947d80855b3793e5f33e7cf51c593d897e6
Author: David Young <[email protected]>
AuthorDate: Mon Feb 23 10:24:10 2026 -0800

    [SPARK-55062][PROTOBUF] Support proto2 extensions in protobuf functions
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for proto2 extensions to `from_protobuf` and 
`to_protobuf` (when file descriptor set is provided, as Java classes do not 
contain enough information to support extensions).
    
    This is done by building an ExtensionRegistry and a map from descriptor 
name to its extensions. The registry is used during construction of the 
DynamicMessage to provide the Protobuf library with visibility of the 
extensions. The index is plumbed through the various helper classes for use in 
schema conversion and serde.
    
    This new functionality is gated behind the Spark config property 
`spark.sql.function.protobufExtensions.enabled`.
    
    ### Why are the changes needed?
    
    Proto2 extensions are a valid, if somewhat uncommon, feature of Protobuf, 
and it therefore makes sense to incorporate them into the schema when provided 
so as to not confuse the user.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Previously, extension fields would be dropped by both `from_protobuf` 
and `to_protobuf`. Now, they are retained. This can be demonstrated with the 
minimal example below. See the unit tests for more examples.
    
    ```proto
    message Person {
        optional int32 id = 1;
        extensions 100 to 200;
    }
    extend Person {
        optional int32 age = 100;
    }
    ```
    
    ### How was this patch tested?
    
    Unit tests were added for the new behavior, including basic behavior, 
extending nested messages, and extensions defined in separate files.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Initial draft authored with Claude Code.
    
    Generated-by: claude-4.5-opus
    
    Closes #53828 from dichlorodiphen/proto-rebase.
    
    Authored-by: David Young <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../sql/protobuf/CatalystDataToProtobuf.scala      |   9 +-
 .../sql/protobuf/ProtobufDataToCatalyst.scala      |  31 +-
 .../spark/sql/protobuf/ProtobufDeserializer.scala  |  76 ++-
 .../spark/sql/protobuf/ProtobufSerializer.scala    |  53 +-
 .../spark/sql/protobuf/utils/ProtobufUtils.scala   | 203 ++++++--
 .../sql/protobuf/utils/SchemaConverters.scala      |  57 ++-
 .../test/resources/protobuf/proto2_messages.proto  |  69 +++
 .../resources/protobuf/proto2_messages_ext.proto   |  32 ++
 .../ProtobufCatalystDataConversionSuite.scala      |  15 +-
 .../sql/protobuf/ProtobufExtensionsSuite.scala     | 538 +++++++++++++++++++++
 .../sql/protobuf/ProtobufFunctionsSuite.scala      |  64 ++-
 .../spark/sql/protobuf/ProtobufSerdeSuite.scala    |  35 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |   9 +
 .../streaming/checkpointing/OffsetSeq.scala        |   5 +-
 .../ProtobufExtensionsSupportOffsetLogSuite.scala  | 110 +++++
 15 files changed, 1155 insertions(+), 151 deletions(-)

diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
index 0564eee1602a..9ed056690d26 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -35,11 +35,16 @@ private[sql] case class CatalystDataToProtobuf(
 
   override def dataType: DataType = BinaryType
 
-  @transient private lazy val protoDescriptor =
+  @transient private lazy val descriptorWithExtensions =
     ProtobufUtils.buildDescriptor(messageName, binaryFileDescriptorSet)
 
+  @transient private lazy val protoDescriptor = 
descriptorWithExtensions.descriptor
+
+  @transient private lazy val fullNamesToExtensions =
+    descriptorWithExtensions.fullNamesToExtensions
+
   @transient private lazy val serializer =
-    new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable)
+    new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable, 
fullNamesToExtensions)
 
   override def nullSafeEval(input: Any): Any = {
     val dynamicMessage = 
serializer.serialize(input).asInstanceOf[DynamicMessage]
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
index b3225d61eb01..864eb93a382b 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -40,17 +40,24 @@ private[sql] case class ProtobufDataToCatalyst(
   override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
 
   override lazy val dataType: DataType =
-    SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType
+    SchemaConverters.toSqlType(messageDescriptor, fullNamesToExtensions, 
protobufOptions).dataType
 
   override def nullable: Boolean = true
 
   private lazy val protobufOptions = ProtobufOptions(options)
 
-  @transient private lazy val messageDescriptor =
+  @transient private lazy val descriptorWithExtensions =
     ProtobufUtils.buildDescriptor(messageName, binaryFileDescriptorSet)
 
-  @transient private lazy val fieldsNumbers =
-    messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet
+  @transient private lazy val messageDescriptor = 
descriptorWithExtensions.descriptor
+
+  @transient private lazy val extensionRegistry = 
descriptorWithExtensions.extensionRegistry
+
+  @transient private lazy val fullNamesToExtensions =
+    descriptorWithExtensions.fullNamesToExtensions
+
+  @transient private lazy val regularFieldNumbers =
+    messageDescriptor.getFields.asScala.map(_.getNumber).toSet
 
   @transient private lazy val deserializer = {
     val typeRegistry = binaryFileDescriptorSet match {
@@ -65,8 +72,8 @@ private[sql] case class ProtobufDataToCatalyst(
       dataType,
       typeRegistry = typeRegistry,
       emitDefaultValues = protobufOptions.emitDefaultValues,
-      enumsAsInts = protobufOptions.enumsAsInts
-    )
+      enumsAsInts = protobufOptions.enumsAsInts,
+      fullNamesToExtensions = fullNamesToExtensions)
   }
 
   @transient private var result: DynamicMessage = _
@@ -92,14 +99,18 @@ private[sql] case class ProtobufDataToCatalyst(
   override def nullSafeEval(input: Any): Any = {
     val binary = input.asInstanceOf[Array[Byte]]
     try {
-      result = DynamicMessage.parseFrom(messageDescriptor, binary)
+      result = DynamicMessage.parseFrom(messageDescriptor, binary, 
extensionRegistry)
       // If the Java class is available, it is likely more efficient to parse 
with it than using
       // DynamicMessage. Can consider it in the future if parsing overhead is 
noticeable.
 
-      
result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_))
 match {
+      result.getUnknownFields
+        .asMap()
+        .keySet()
+        .asScala
+        .find(regularFieldNumbers.contains(_)) match {
         case Some(number) =>
-          // Unknown fields contain a field with same number as a known field. 
Must be due to
-          // mismatch of schema between writer and reader here.
+          // Unknown fields contain a field with same number as a known 
regular field. Must be due
+          // to mismatch of schema between writer and reader here.
           throw QueryCompilationErrors.protobufFieldTypeMismatchError(
             messageDescriptor.getFields.get(number).toString)
         case None =>
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index fa6567ae2aa5..9f2bafb8a180 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -40,7 +40,8 @@ private[sql] class ProtobufDeserializer(
     filters: StructFilters = new NoopFilters,
     typeRegistry: TypeRegistry = TypeRegistry.getEmptyTypeRegistry,
     emitDefaultValues: Boolean = false,
-    enumsAsInts: Boolean = false) {
+    enumsAsInts: Boolean = false,
+    fullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = Map.empty) {
 
   def this(rootDescriptor: Descriptor, rootCatalystType: DataType) = {
     this(
@@ -59,7 +60,8 @@ private[sql] class ProtobufDeserializer(
           val resultRow = new SpecificInternalRow(st.map(_.dataType))
           val fieldUpdater = new RowUpdater(resultRow)
           val applyFilters = filters.skipRow(resultRow, _)
-          val writer = getRecordWriter(rootDescriptor, st, Nil, Nil, 
applyFilters)
+          val writer =
+            getRecordWriter(rootDescriptor, st, Nil, Nil, applyFilters, 
fullNamesToExtensions)
           (data: Any) => {
             val record = data.asInstanceOf[DynamicMessage]
             val skipRow = writer(fieldUpdater, record)
@@ -97,11 +99,18 @@ private[sql] class ProtobufDeserializer(
       protoPath: Seq[String],
       catalystPath: Seq[String],
       elementType: DataType,
-      containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+      containsNull: Boolean,
+      protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = 
Map.empty)
+      : (CatalystDataUpdater, Int, Any) => Unit = {
 
     val protoElementPath = protoPath :+ "element"
     val elementWriter =
-      newWriter(protoField, elementType, protoElementPath, catalystPath :+ 
"element")
+      newWriter(
+        protoField,
+        elementType,
+        protoElementPath,
+        catalystPath :+ "element",
+        protoFullNamesToExtensions)
     (updater, ordinal, value) =>
       val collection = value.asInstanceOf[java.util.Collection[Any]]
       val result = createArrayData(elementType, collection.size())
@@ -133,12 +142,24 @@ private[sql] class ProtobufDeserializer(
       catalystPath: Seq[String],
       keyType: DataType,
       valueType: DataType,
-      valueContainsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+      valueContainsNull: Boolean,
+      protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = 
Map.empty)
+      : (CatalystDataUpdater, Int, Any) => Unit = {
     val keyField = protoType.getMessageType.getFields.get(0)
     val valueField = protoType.getMessageType.getFields.get(1)
-    val keyWriter = newWriter(keyField, keyType, protoPath :+ "key", 
catalystPath :+ "key")
+    val keyWriter = newWriter(
+      keyField,
+      keyType,
+      protoPath :+ "key",
+      catalystPath :+ "key",
+      protoFullNamesToExtensions)
     val valueWriter =
-      newWriter(valueField, valueType, protoPath :+ "value", catalystPath :+ 
"value")
+      newWriter(
+        valueField,
+        valueType,
+        protoPath :+ "value",
+        catalystPath :+ "value",
+        protoFullNamesToExtensions)
     (updater, ordinal, value) =>
       if (value != null) {
         val messageList = 
value.asInstanceOf[java.util.List[com.google.protobuf.Message]]
@@ -174,7 +195,9 @@ private[sql] class ProtobufDeserializer(
       protoType: FieldDescriptor,
       catalystType: DataType,
       protoPath: Seq[String],
-      catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+      catalystPath: Seq[String],
+      protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = 
Map.empty)
+      : (CatalystDataUpdater, Int, Any) => Unit = {
 
     (protoType.getJavaType, catalystType) match {
 
@@ -201,7 +224,13 @@ private[sql] class ProtobufDeserializer(
       case  (
         MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | 
BYTE_STRING,
         ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
-        newArrayWriter(protoType, protoPath, catalystPath, dataType, 
containsNull)
+        newArrayWriter(
+          protoType,
+          protoPath,
+          catalystPath,
+          dataType,
+          containsNull,
+          protoFullNamesToExtensions)
 
       case (LONG, LongType) =>
         (updater, ordinal, value) => updater.setLong(ordinal, 
value.asInstanceOf[Long])
@@ -236,7 +265,14 @@ private[sql] class ProtobufDeserializer(
           updater.set(ordinal, byte_array)
 
       case (MESSAGE, MapType(keyType, valueType, valueContainsNull)) =>
-        newMapWriter(protoType, protoPath, catalystPath, keyType, valueType, 
valueContainsNull)
+        newMapWriter(
+          protoType,
+          protoPath,
+          catalystPath,
+          keyType,
+          valueType,
+          valueContainsNull,
+          protoFullNamesToExtensions)
 
       case (MESSAGE, TimestampType) =>
         (updater, ordinal, value) =>
@@ -368,7 +404,8 @@ private[sql] class ProtobufDeserializer(
           st,
           protoPath,
           catalystPath,
-          applyFilters = _ => false)
+          applyFilters = _ => false,
+          protoFullNamesToExtensions)
         (updater, ordinal, value) =>
           val row = new SpecificInternalRow(st)
           writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
@@ -399,10 +436,20 @@ private[sql] class ProtobufDeserializer(
       catalystType: StructType,
       protoPath: Seq[String],
       catalystPath: Seq[String],
-      applyFilters: Int => Boolean): (CatalystDataUpdater, DynamicMessage) => 
Boolean = {
+      applyFilters: Int => Boolean,
+      protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = 
Map.empty)
+      : (CatalystDataUpdater, DynamicMessage) => Boolean = {
 
+    // Get extension fields for this specific message type
+    val protoExtensionFields =
+      protoFullNamesToExtensions.getOrElse(protoType.getFullName, Seq.empty)
     val protoSchemaHelper =
-      new ProtobufUtils.ProtoSchemaHelper(protoType, catalystType, protoPath, 
catalystPath)
+      new ProtobufUtils.ProtoSchemaHelper(
+        protoType,
+        catalystType,
+        protoPath,
+        catalystPath,
+        protoExtensionFields)
 
     // TODO revisit validation of protobuf-catalyst fields.
     // protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true)
@@ -414,7 +461,8 @@ private[sql] class ProtobufDeserializer(
           protoField,
           catalystField.dataType,
           protoPath :+ protoField.getName,
-          catalystPath :+ catalystField.name)
+          catalystPath :+ catalystField.name,
+          protoFullNamesToExtensions)
         val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
           if (value == null) {
             fieldUpdater.setNullAt(ordinal)
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
index 65e8cce0d056..cd638f3ac12a 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types._
 private[sql] class ProtobufSerializer(
     rootCatalystType: DataType,
     rootDescriptor: Descriptor,
-    nullable: Boolean)
+    nullable: Boolean,
+    fullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = Map.empty)
     extends Logging {
 
   def serialize(catalystData: Any): Any = {
@@ -54,7 +55,8 @@ private[sql] class ProtobufSerializer(
       try {
         rootCatalystType match {
           case st: StructType =>
-            newStructConverter(st, rootDescriptor, Nil, Nil).asInstanceOf[Any 
=> Any]
+            newStructConverter(st, rootDescriptor, Nil, Nil, 
fullNamesToExtensions)
+              .asInstanceOf[Any => Any]
         }
       } catch {
         case ise: AnalysisException =>
@@ -80,7 +82,8 @@ private[sql] class ProtobufSerializer(
       catalystType: DataType,
       fieldDescriptor: FieldDescriptor,
       catalystPath: Seq[String],
-      protoPath: Seq[String]): Converter = {
+      protoPath: Seq[String],
+      protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]]): 
Converter = {
     (catalystType, fieldDescriptor.getJavaType) match {
       case (NullType, _) =>
         (getter, ordinal) => null
@@ -157,7 +160,12 @@ private[sql] class ProtobufSerializer(
 
       case (ArrayType(et, containsNull), _) =>
         val elementConverter =
-          newConverter(et, fieldDescriptor, catalystPath :+ "element", 
protoPath :+ "element")
+          newConverter(
+            et,
+            fieldDescriptor,
+            catalystPath :+ "element",
+            protoPath :+ "element",
+            protoFullNamesToExtensions)
         (getter, ordinal) => {
           val arrayData = getter.getArray(ordinal)
           val len = arrayData.numElements()
@@ -224,7 +232,12 @@ private[sql] class ProtobufSerializer(
 
       case (st: StructType, MESSAGE) =>
         val structConverter =
-          newStructConverter(st, fieldDescriptor.getMessageType, catalystPath, 
protoPath)
+          newStructConverter(
+            st,
+            fieldDescriptor.getMessageType,
+            catalystPath,
+            protoPath,
+            protoFullNamesToExtensions)
         val numFields = st.length
         (getter, ordinal) => structConverter(getter.getStruct(ordinal, 
numFields))
 
@@ -240,9 +253,19 @@ private[sql] class ProtobufSerializer(
           }
         }
 
-        val keyConverter = newConverter(kt, keyField, catalystPath :+ "key", 
protoPath :+ "key")
+        val keyConverter = newConverter(
+          kt,
+          keyField,
+          catalystPath :+ "key",
+          protoPath :+ "key",
+          protoFullNamesToExtensions)
         val valueConverter =
-          newConverter(vt, valueField, catalystPath :+ "value", protoPath :+ 
"value")
+          newConverter(
+            vt,
+            valueField,
+            catalystPath :+ "value",
+            protoPath :+ "value",
+            protoFullNamesToExtensions)
 
         (getter, ordinal) =>
           val mapData = getter.getMap(ordinal)
@@ -301,10 +324,19 @@ private[sql] class ProtobufSerializer(
       catalystStruct: StructType,
       descriptor: Descriptor,
       catalystPath: Seq[String],
-      protoPath: Seq[String]): InternalRow => DynamicMessage = {
+      protoPath: Seq[String],
+      protoFullNamesToExtensions: Map[String, Seq[FieldDescriptor]])
+      : InternalRow => DynamicMessage = {
 
+    val protoExtensionFields =
+      protoFullNamesToExtensions.getOrElse(descriptor.getFullName, Seq.empty)
     val protoSchemaHelper =
-      new ProtobufUtils.ProtoSchemaHelper(descriptor, catalystStruct, 
protoPath, catalystPath)
+      new ProtobufUtils.ProtoSchemaHelper(
+        descriptor,
+        catalystStruct,
+        protoPath,
+        catalystPath,
+        protoExtensionFields)
 
     protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false)
     protoSchemaHelper.validateNoExtraRequiredProtoFields()
@@ -315,7 +347,8 @@ private[sql] class ProtobufSerializer(
           catalystField.dataType,
           protoField,
           catalystPath :+ catalystField.name,
-          protoPath :+ protoField.getName)
+          protoPath :+ protoField.getName,
+          protoFullNamesToExtensions)
         (protoField, converter)
       }
       .toArray
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index 3d7bba7a82e8..47898f73d016 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.protobuf.utils
 
 import java.util.Locale
 
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
-import com.google.protobuf.{DescriptorProtos, Descriptors, 
InvalidProtocolBufferException, Message}
+import com.google.protobuf.{DescriptorProtos, Descriptors, DynamicMessage, 
ExtensionRegistry, InvalidProtocolBufferException, Message}
 import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, 
FileDescriptorSet}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.Descriptors.FieldDescriptor.JavaType
 import com.google.protobuf.TypeRegistry
 
 import org.apache.spark.internal.Logging
@@ -40,6 +43,21 @@ private[sql] object ProtobufUtils extends Logging {
       catalystPosition: Int,
       fieldDescriptor: FieldDescriptor)
 
+  /**
+   * Container for a Protobuf descriptor and associated extensions.
+   *
+   * @param descriptor
+   *   The descriptor for the top-level message.
+   * @param extensionRegistry
+   *   An extension registry populated with all discovered extensions.
+   * @param fullNamesToExtensions
+   *   A map from message full names to their extension fields.
+   */
+  private[sql] case class DescriptorWithExtensions(
+      descriptor: Descriptor,
+      extensionRegistry: ExtensionRegistry,
+      fullNamesToExtensions: Map[String, Seq[FieldDescriptor]])
+
   /**
    * Helper class to perform field lookup/matching on Protobuf schemas.
    *
@@ -55,20 +73,25 @@ private[sql] object ProtobufUtils extends Logging {
    *   The seq of parent field names leading to `protoSchema`.
    * @param catalystPath
    *   The seq of parent field names leading to `catalystSchema`.
+   * @param extensionFields
+   *   Optional sequence of extension fields to include in field matching.
    */
   class ProtoSchemaHelper(
       descriptor: Descriptor,
       catalystSchema: StructType,
       protoPath: Seq[String],
-      catalystPath: Seq[String]) {
+      catalystPath: Seq[String],
+      extensionFields: Seq[FieldDescriptor] = Seq.empty) {
     if (descriptor.getName == null) {
       throw QueryCompilationErrors.unknownProtobufMessageTypeError(
         descriptor.getName,
         descriptor.getContainingType().getName)
     }
 
-    private[this] val protoFieldArray = descriptor.getFields.asScala.toArray
-    private[this] val fieldMap = descriptor.getFields.asScala
+    // Combine regular fields with extension fields
+    private[this] val protoFieldArray =
+      (descriptor.getFields.asScala ++ extensionFields).toArray
+    private[this] val fieldMap = protoFieldArray
       .groupBy(_.getName.toLowerCase(Locale.ROOT))
       .transform((_, v) => v.toSeq) // toSeq needed for scala 2.13
 
@@ -135,27 +158,34 @@ private[sql] object ProtobufUtils extends Logging {
   }
 
   /**
-   * Builds Protobuf message descriptor either from the Java class or from 
serialized descriptor
-   * read from the file.
+   * Builds a Protobuf descriptor along with an ExtensionRegistry containing 
all extensions found
+   * in the FileDescriptorSet.
+   *
    * @param messageName
-   *  Protobuf message name or Java class name (when binaryFileDescriptorSet 
is None)..
+   *   Protobuf message name or Java class name (when binaryFileDescriptorSet 
is None).
    * @param binaryFileDescriptorSet
-   *  When the binary `FileDescriptorSet` is provided, the descriptor and its 
dependencies are
-   *  read from it.
+   *   When the binary `FileDescriptorSet` is provided, the descriptor, 
extensions, and registry
+   *   are read from it. When None, the descriptor is loaded from the Java 
class and extensions
+   *   are empty.
    * @return
+   *   DescriptorWithExtensions containing descriptor, registry, and extension 
fields map
    */
-  def buildDescriptor(messageName: String, binaryFileDescriptorSet: 
Option[Array[Byte]])
-  : Descriptor = {
+  def buildDescriptor(
+      messageName: String,
+      binaryFileDescriptorSet: Option[Array[Byte]]): DescriptorWithExtensions =
     binaryFileDescriptorSet match {
-      case Some(bytes) => buildDescriptor(bytes, messageName)
+      case Some(bytes) => buildDescriptorFromFDS(messageName, bytes)
       case None => buildDescriptorFromJavaClass(messageName)
     }
-  }
 
   /**
-   *  Loads the given protobuf class and returns Protobuf descriptor for it.
+   * Loads the given protobuf class and returns Protobuf descriptor for it.
+   *
+   * Given a Java class, we can only access the descriptor for the proto file 
it is defined
+   * in. Extensions in other files will not be picked up. As such, we choose 
to disable
+   * extension support when we fall back to the Java class.
    */
-  def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = {
+  def buildDescriptorFromJavaClass(protobufClassName: String): 
DescriptorWithExtensions = {
 
     // Default 'Message' class here is shaded while using the package (as in 
production).
     // The incoming classes might not be shaded. Check both.
@@ -203,23 +233,32 @@ private[sql] object ProtobufUtils extends Logging {
           protobufClassName, "Could not find getDescriptor() method", e)
     }
 
-    getDescriptorMethod
-      .invoke(null)
-      .asInstanceOf[Descriptor]
+    val descriptor = getDescriptorMethod.invoke(null).asInstanceOf[Descriptor]
+
+    DescriptorWithExtensions(descriptor, ExtensionRegistry.getEmptyRegistry, 
Map.empty)
   }
 
-  def buildDescriptor(binaryFileDescriptorSet: Array[Byte], messageName: 
String): Descriptor = {
+  def buildDescriptorFromFDS(
+      messageName: String,
+      binaryFileDescriptorSet: Array[Byte]): DescriptorWithExtensions = {
     // Find the first message descriptor that matches the name.
-    val descriptorOpt = parseFileDescriptorSet(binaryFileDescriptorSet)
+    val fileDescriptors = parseFileDescriptorSet(binaryFileDescriptorSet)
+    val descriptor = fileDescriptors
       .flatMap { fileDesc =>
         fileDesc.getMessageTypes.asScala.find { desc =>
           desc.getName == messageName || desc.getFullName == messageName
         }
-      }.headOption
+      }
+      .headOption
+      .getOrElse {
+        throw 
QueryCompilationErrors.unableToLocateProtobufMessageError(messageName)
+      }
+    val (extensionRegistry, fullNamesToExtensions) = 
buildExtensionRegistry(fileDescriptors)
 
-    descriptorOpt match {
-      case Some(d) => d
-      case None => throw 
QueryCompilationErrors.unableToLocateProtobufMessageError(messageName)
+    if (SQLConf.get.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED)) {
+      DescriptorWithExtensions(descriptor, extensionRegistry, 
fullNamesToExtensions)
+    } else {
+      DescriptorWithExtensions(descriptor, ExtensionRegistry.getEmptyRegistry, 
Map.empty)
     }
   }
 
@@ -232,38 +271,49 @@ private[sql] object ProtobufUtils extends Logging {
         throw QueryCompilationErrors.descriptorParseError(ex)
     }
     val fileDescriptorProtoIndex = createDescriptorProtoMap(fileDescriptorSet)
+
+    // Mutated across invocations of buildFileDescriptor.
+    val builtDescriptors = mutable.Map[String, Descriptors.FileDescriptor]()
     val fileDescriptorList: List[Descriptors.FileDescriptor] =
-      fileDescriptorSet.getFileList.asScala.map( fileDescriptorProto =>
-        buildFileDescriptor(fileDescriptorProto, fileDescriptorProtoIndex)
-      ).toList
+      fileDescriptorSet.getFileList.asScala.map { fileDescriptorProto =>
+        buildFileDescriptor(fileDescriptorProto, fileDescriptorProtoIndex, 
builtDescriptors)
+      }.distinctBy(_.getFullName).toList
     fileDescriptorList
   }
 
   /**
-   * Recursively constructs file descriptors for all dependencies for given
-   * FileDescriptorProto and return.
+   * Recursively constructs file descriptors for all dependencies for given 
FileDescriptorProto
+   * and return.
    */
   private def buildFileDescriptor(
-    fileDescriptorProto: FileDescriptorProto,
-    fileDescriptorProtoMap: Map[String, FileDescriptorProto]): 
Descriptors.FileDescriptor = {
-    val fileDescriptorList = 
fileDescriptorProto.getDependencyList().asScala.map { dependency =>
-      fileDescriptorProtoMap.get(dependency) match {
-        case Some(dependencyProto) =>
-          if (dependencyProto.getName == "google/protobuf/any.proto"
-            && dependencyProto.getPackage == "google.protobuf") {
-            // For Any, use the descriptor already included as part of the 
Java dependency.
-            // Without this, JsonFormat used for converting Any fields fails 
when
-            // an Any field in input is set to `Any.getDefaultInstance()`.
-            com.google.protobuf.AnyProto.getDescriptor
-            // Should we do the same for timestamp.proto and empty.proto?
-          } else {
-            buildFileDescriptor(dependencyProto, fileDescriptorProtoMap)
-          }
-        case None =>
-          throw 
QueryCompilationErrors.protobufDescriptorDependencyError(dependency)
-      }
-    }
-    Descriptors.FileDescriptor.buildFrom(fileDescriptorProto, 
fileDescriptorList.toArray)
+      fileDescriptorProto: FileDescriptorProto,
+      fileDescriptorProtoMap: Map[String, FileDescriptorProto],
+      builtDescriptors: mutable.Map[String, Descriptors.FileDescriptor])
+      : Descriptors.FileDescriptor = {
+    // Storing references to constructed descriptors is crucial because 
descriptors are compared
+    // by reference inside the Protobuf library.
+    builtDescriptors.getOrElseUpdate(
+      fileDescriptorProto.getName, {
+        val fileDescriptorList = 
fileDescriptorProto.getDependencyList().asScala.map {
+          dependency =>
+            fileDescriptorProtoMap.get(dependency) match {
+              case Some(dependencyProto) =>
+                if (dependencyProto.getName == "google/protobuf/any.proto"
+                  && dependencyProto.getPackage == "google.protobuf") {
+                  // For Any, use the descriptor already included as part of 
the Java dependency.
+                  // Without this, JsonFormat used for converting Any fields 
fails when
+                  // an Any field in input is set to 
`Any.getDefaultInstance()`.
+                  com.google.protobuf.AnyProto.getDescriptor
+                  // Should we do the same for timestamp.proto and empty.proto?
+                } else {
+                  buildFileDescriptor(dependencyProto, fileDescriptorProtoMap, 
builtDescriptors)
+                }
+              case None =>
+                throw 
QueryCompilationErrors.protobufDescriptorDependencyError(dependency)
+            }
+        }
+        Descriptors.FileDescriptor.buildFrom(fileDescriptorProto, 
fileDescriptorList.toArray)
+      })
   }
 
   /**
@@ -296,9 +346,60 @@ private[sql] object ProtobufUtils extends Logging {
   }
 
   /** Builds [[TypeRegistry]] with the descriptor and the others from the same 
proto file. */
-  private [protobuf] def buildTypeRegistry(descriptor: Descriptor): 
TypeRegistry = {
-    TypeRegistry.newBuilder()
+  private[protobuf] def buildTypeRegistry(descriptor: Descriptor): 
TypeRegistry = {
+    TypeRegistry
+      .newBuilder()
       .add(descriptor) // This adds any other descriptors in the associated 
proto file.
       .build()
   }
+
+  /**
+   * Builds an ExtensionRegistry and an index from full name to field 
descriptor for all extensions
+   * found in the list of provided file descriptors.
+   *
+   * This method recursively traverses nested message definitions to ensure extensions declared
+   * in nested scopes are registered as well.
+   *
+   * @param fileDescriptors
+   *   List of all file descriptors to process
+   * @return
+   *   The populated ExtensionRegistry and a map from message full names to 
extension fields
+   *   sorted by field number
+   */
+  private def buildExtensionRegistry(fileDescriptors: 
List[Descriptors.FileDescriptor])
+      : (ExtensionRegistry, Map[String, Seq[FieldDescriptor]]) = {
+    val registry = ExtensionRegistry.newInstance()
+    val fullNameToExtensions = mutable.Map[String, 
ArrayBuffer[FieldDescriptor]]()
+
+    // Adds an extension to both the registry and map.
+    def addExtension(extField: FieldDescriptor): Unit = {
+      val extendeeName = extField.getContainingType.getFullName
+      // For message-type extensions, we need to provide a default instance.
+      if (extField.getJavaType == JavaType.MESSAGE) {
+        val defaultInstance = 
DynamicMessage.getDefaultInstance(extField.getMessageType)
+        registry.add(extField, defaultInstance)
+      } else {
+        registry.add(extField)
+      }
+      fullNameToExtensions
+        .getOrElseUpdate(extendeeName, mutable.ArrayBuffer())
+        .append(extField)
+    }
+
+    for (fileDesc <- fileDescriptors) {
+      fileDesc.getExtensions.asScala.foreach(addExtension)
+
+      // Recursively add nested extensions.
+      def collectNestedExtensions(msgDesc: Descriptor): Unit = {
+        msgDesc.getExtensions.asScala.foreach(addExtension)
+        msgDesc.getNestedTypes.asScala.foreach(collectNestedExtensions)
+      }
+      fileDesc.getMessageTypes.asScala.foreach(collectNestedExtensions)
+    }
+
+    // Sort extension fields by field number for consistent ordering.
+    val sortedMap = fullNameToExtensions.map { case (name, extensions) =>
+      name -> extensions.sortBy(_.getNumber).toSeq
+    }.toMap
+    (registry, sortedMap)
+  }
 }
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index 3273f473ed5c..5a73d9aaa72b 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -38,34 +38,43 @@ object SchemaConverters extends Logging {
   /**
    * Converts an Protobuf schema to a corresponding Spark SQL schema.
    *
+   * fullNamesToExtensions is a map from full message names to the field 
descriptors of their
+   * extensions. It is non-empty when proto2 extension support is enabled.
+   *
    * @since 3.4.0
    */
   private[protobuf] def toSqlType(
       descriptor: Descriptor,
+      fullNamesToExtensions: Map[String, Seq[FieldDescriptor]] = Map.empty,
       protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)): 
SchemaType = {
-    toSqlTypeHelper(descriptor, protobufOptions)
-  }
+    val existingRecordNames = Map(descriptor.getFullName -> 1)
+    val regularFields = descriptor.getFields.asScala
+      .flatMap(structFieldFor(_, existingRecordNames, fullNamesToExtensions, 
protobufOptions))
+      .toSeq
+    val extensionFields = 
fullNamesToExtensions.getOrElse(descriptor.getFullName, Seq.empty)
+    val extStructFields = extensionFields.flatMap(
+      structFieldFor(_, existingRecordNames, fullNamesToExtensions, 
protobufOptions))
 
-  private[protobuf] def toSqlTypeHelper(
-      descriptor: Descriptor,
-      protobufOptions: ProtobufOptions): SchemaType = {
-    val fields = descriptor.getFields.asScala.flatMap(
-      structFieldFor(_,
-        Map(descriptor.getFullName -> 1),
-        protobufOptions: ProtobufOptions)).toSeq
-    if (fields.isEmpty && protobufOptions.retainEmptyMessage) {
+    val allFields = regularFields ++ extStructFields
+
+    if (allFields.isEmpty && protobufOptions.retainEmptyMessage) {
       
SchemaType(convertEmptyProtoToStructWithDummyField(descriptor.getFullName), 
nullable = true)
-    } else SchemaType(StructType(fields), nullable = true)
+    } else {
+      SchemaType(StructType(allFields), nullable = true)
+    }
   }
 
   // existingRecordNames: Map[String, Int] used to track the depth of 
recursive fields and to
   // ensure that the conversion of the protobuf message to a Spark SQL 
StructType object does not
   // exceed the maximum recursive depth specified by the 
recursiveFieldMaxDepth option.
+  // fullNamesToExtensions: a map from Protobuf message full names to the 
field descriptors of
+  // their extensions used to add extension fields to the generated 
StructField.
   // A return of None implies the field has reached the maximum allowed 
recursive depth and
   // should be dropped.
   private def structFieldFor(
       fd: FieldDescriptor,
       existingRecordNames: Map[String, Int],
+      fullNamesToExtensions: Map[String, Seq[FieldDescriptor]],
       protobufOptions: ProtobufOptions): Option[StructField] = {
     import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
 
@@ -150,16 +159,12 @@ object SchemaConverters extends Logging {
           field.getName match {
             case "key" =>
               keyType =
-                structFieldFor(
-                  field,
-                  existingRecordNames,
-                  protobufOptions).map(_.dataType)
+                structFieldFor(field, existingRecordNames, 
fullNamesToExtensions, protobufOptions)
+                  .map(_.dataType)
             case "value" =>
               valueType =
-                structFieldFor(
-                  field,
-                  existingRecordNames,
-                  protobufOptions).map(_.dataType)
+                structFieldFor(field, existingRecordNames, 
fullNamesToExtensions, protobufOptions)
+                  .map(_.dataType)
           }
         }
         (keyType, valueType) match {
@@ -206,10 +211,16 @@ object SchemaConverters extends Logging {
           None
         } else {
           val newRecordNames = existingRecordNames + (recordName -> 
(recursiveDepth + 1))
-          val fields = fd.getMessageType.getFields.asScala.flatMap(
-            structFieldFor(_, newRecordNames, protobufOptions)
-          ).toSeq
-          fields match {
+          // Get regular fields for the nested message
+          val regularFields = fd.getMessageType.getFields.asScala
+            .flatMap(structFieldFor(_, newRecordNames, fullNamesToExtensions, 
protobufOptions))
+            .toSeq
+          // Get extension fields for the nested message
+          val nestedExtensions = fullNamesToExtensions.getOrElse(recordName, 
Seq.empty)
+          val extFields = nestedExtensions.flatMap(
+            structFieldFor(_, newRecordNames, fullNamesToExtensions, 
protobufOptions))
+          val allFields = regularFields ++ extFields
+          allFields match {
             case Nil =>
               if (protobufOptions.retainEmptyMessage) {
                 Some(convertEmptyProtoToStructWithDummyField(fd.getFullName))
diff --git 
a/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto 
b/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
index 75525715a3d5..2d909bd4d8bd 100644
--- a/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
+++ b/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
@@ -55,3 +55,72 @@ message Proto2AllTypes {
   }
   map<string, string> map = 13;
 }
+
+// Message with extension range for testing proto2 extension support
+message Proto2ExtensionTest {
+  optional string name = 1;
+  optional int32 id = 2;
+
+  extensions 100 to 200;
+}
+
+// Identical to Proto2ExtensionTest but without any extensions.
+// Used for schema evolution testing.
+message Proto2ExtensionTestBase {
+  optional string name = 1;
+  optional int32 id = 2;
+}
+
+// File-level extensions for Proto2ExtensionTest
+extend Proto2ExtensionTest {
+  optional string extension_string = 100;
+  optional int32 extension_int = 101;
+  optional NestedExtensionMessage extension_message = 102;
+  repeated int32 extension_repeated_int = 103;
+  optional Proto2AllTypes.NestedEnum extension_enum = 104;
+}
+
+// Nested message type used in an extension field
+message NestedExtensionMessage {
+  optional string nested_value = 1;
+  optional int32 nested_id = 2;
+}
+
+message ContainerWithNestedExtension {
+  optional string container_name = 1;
+
+  // Extension within the scope of another message.
+  extend Proto2ExtensionTest {
+    optional bool nested_extension_bool = 150;
+  }
+}
+
+// Test nested message extensions: a top-level message containing a nested
+// message type that itself has extensions.
+message MessageWithExtendableNested {
+  optional string top_level_name = 1;
+  optional ExtendableNestedMessage nested = 2;
+}
+
+message ExtendableNestedMessage {
+  optional string nested_name = 1;
+  optional int32 nested_id = 2;
+
+  extensions 100 to 200;
+}
+
+extend ExtendableNestedMessage {
+  optional string nested_ext_field = 100;
+  optional int32 nested_ext_int = 101;
+}
+
+// Message for testing extension field name collision
+message Proto2ExtensionCollisionTest {
+  optional string name = 1;
+  extensions 100 to 200;
+}
+
+// Extension that collides with "name" field (case-insensitive)
+extend Proto2ExtensionCollisionTest {
+  optional string NAME = 100;
+}
diff --git 
a/connector/protobuf/src/test/resources/protobuf/proto2_messages_ext.proto 
b/connector/protobuf/src/test/resources/protobuf/proto2_messages_ext.proto
new file mode 100644
index 000000000000..40ae8008d544
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/proto2_messages_ext.proto
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf.protos;
+
+import "proto2_messages.proto";
+
+option java_outer_classname = "Proto2MessagesExt";
+
+message Proto2ExtensionCrossFile {
+    optional int32 foo = 1;
+
+    extend Proto2ExtensionTest {
+        optional Proto2ExtensionCrossFile cross_file_extension = 200;
+    }
+}
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index 1802bceee1df..5f89138d7248 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -152,7 +152,8 @@ class ProtobufCatalystDataConversionSuite
       expected: Option[Any],
       filters: StructFilters = new NoopFilters): Unit = {
 
-    val descriptor = ProtobufUtils.buildDescriptor(descFileBytes, messageName)
+    val descriptor =
+      ProtobufUtils.buildDescriptor(messageName, 
Some(descFileBytes)).descriptor
     val dataType = SchemaConverters.toSqlType(descriptor).dataType
 
     val deserializer = new ProtobufDeserializer(descriptor, dataType, filters)
@@ -163,6 +164,7 @@ class ProtobufCatalystDataConversionSuite
     // Verify Java class deserializer matches with descriptor based serializer.
     val javaDescriptor = ProtobufUtils
       .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName")
+      .descriptor
     assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType)
     val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, 
filters)
       .deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray))
@@ -199,7 +201,8 @@ class ProtobufCatalystDataConversionSuite
       .add("name", "string")
       .add("age", "int")
 
-    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person")
+    val descriptor =
+      ProtobufUtils.buildDescriptor("Person", Some(testFileDesc)).descriptor
     val dynamicMessage = DynamicMessage
       .newBuilder(descriptor)
       .setField(descriptor.findFieldByName("name"), "Maxim")
@@ -237,11 +240,13 @@ class ProtobufCatalystDataConversionSuite
   }
 
   test("Full names for message using descriptor file") {
-    val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
+    val withShortName =
+      ProtobufUtils.buildDescriptor("BytesMsg", Some(testFileDesc)).descriptor
     assert(withShortName.findFieldByName("bytes_type") != null)
 
-    val withFullName = ProtobufUtils.buildDescriptor(
-      testFileDesc, "org.apache.spark.sql.protobuf.protos.BytesMsg")
+    val withFullName = ProtobufUtils
+      .buildDescriptor("org.apache.spark.sql.protobuf.protos.BytesMsg", 
Some(testFileDesc))
+      .descriptor
     assert(withFullName.findFieldByName("bytes_type") != null)
   }
 
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufExtensionsSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufExtensionsSuite.scala
new file mode 100644
index 000000000000..10c86c556cc5
--- /dev/null
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufExtensionsSuite.scala
@@ -0,0 +1,538 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import com.google.protobuf.DescriptorProtos.FileDescriptorSet
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import 
org.apache.spark.sql.internal.SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
+
+class ProtobufExtensionsSuite
+    extends QueryTest
+    with SharedSparkSession
+    with ProtobufTestBase
+    with Serializable {
+
+  import testImplicits._
+
+  val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
+  val proto2FileDesc = 
CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
+  private val proto2JavaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.Proto2Messages$"
+
+  test("SPARK-55062: roundtrip - proto2 extension basic types") {
+    withExtensionSupport {
+      val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = 
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+      val extStringField = extensionFields.find(_.getName == 
"extension_string").get
+      val extIntField = extensionFields.find(_.getName == "extension_int").get
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "test_name")
+        .setField(descriptor.findFieldByName("id"), 42)
+        .setField(extStringField, "ext_value")
+        .setField(extIntField, 123)
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions.from_protobuf($"value", "Proto2ExtensionTest", 
proto2FileDesc).as("value_from"))
+
+      // Verify extension field values are correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("name") == "test_name")
+      assert(row.getAs[Int]("id") == 42)
+      assert(row.getAs[String]("extension_string") == "ext_value")
+      assert(row.getAs[Int]("extension_int") == 123)
+
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: roundtrip - proto2 extension enum") {
+    withExtensionSupport {
+      val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = 
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+      val extEnumField = extensionFields.find(_.getName == 
"extension_enum").get
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "enum_test")
+        .setField(descriptor.findFieldByName("id"), 99)
+        .setField(extEnumField, 
extEnumField.getEnumType.findValueByName("FIRST"))
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions.from_protobuf($"value", "Proto2ExtensionTest", 
proto2FileDesc).as("value_from"))
+
+      // Verify extension field value is correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("name") == "enum_test")
+      assert(row.getAs[Int]("id") == 99)
+      assert(row.getAs[String]("extension_enum") == "FIRST")
+
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: roundtrip - proto2 extension nested message") {
+    withExtensionSupport {
+      val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = descWithExt.fullNamesToExtensions
+        .getOrElse(descriptor.getFullName, Seq.empty)
+
+      val extMessageField = extensionFields.find(_.getName == 
"extension_message").get
+
+      val nestedMessage = DynamicMessage
+        .newBuilder(extMessageField.getMessageType)
+        
.setField(extMessageField.getMessageType.findFieldByName("nested_value"), 
"nested_test")
+        .setField(extMessageField.getMessageType.findFieldByName("nested_id"), 
99)
+        .build()
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "main")
+        .setField(extMessageField, nestedMessage)
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions.from_protobuf($"value", "Proto2ExtensionTest", 
proto2FileDesc).as("value_from"))
+
+      // Verify extension field value is correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("name") == "main")
+      val nestedRow = row.getAs[Row]("extension_message")
+      assert(nestedRow.getAs[String]("nested_value") == "nested_test")
+      assert(nestedRow.getAs[Int]("nested_id") == 99)
+
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: roundtrip - proto2 extension repeated") {
+    withExtensionSupport {
+      val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = descWithExt.fullNamesToExtensions
+        .getOrElse(descriptor.getFullName, Seq.empty)
+
+      val extRepeatedField = extensionFields.find(_.getName == 
"extension_repeated_int").get
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "repeated_test")
+        .addRepeatedField(extRepeatedField, 1)
+        .addRepeatedField(extRepeatedField, 2)
+        .addRepeatedField(extRepeatedField, 3)
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions.from_protobuf($"value", "Proto2ExtensionTest", 
proto2FileDesc).as("value_from"))
+
+      // Verify extension field value is correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("name") == "repeated_test")
+      val repeatedValues = 
row.getSeq[Int](row.fieldIndex("extension_repeated_int"))
+      assert(repeatedValues == Seq(1, 2, 3))
+
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: roundtrip - proto2 extension defined in another message") 
{
+    withExtensionSupport {
+      val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = descWithExt.fullNamesToExtensions
+        .getOrElse(descriptor.getFullName, Seq.empty)
+
+      val nestedExtBoolField = extensionFields.find(_.getName == 
"nested_extension_bool")
+      assert(
+        nestedExtBoolField.isDefined,
+        "Should find extension defined in message 
ContainerWithNestedExtension")
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "nested_ext_test")
+        .setField(nestedExtBoolField.get, true)
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions.from_protobuf($"value", "Proto2ExtensionTest", 
proto2FileDesc).as("value_from"))
+
+      // Verify extension field value is correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("name") == "nested_ext_test")
+      assert(row.getAs[Boolean]("nested_extension_bool"))
+
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "Proto2ExtensionTest", proto2FileDesc)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: roundtrip - proto2 extension on nested message") {
+    withExtensionSupport {
+      val descWithExt =
+        ProtobufUtils.buildDescriptor("MessageWithExtendableNested", 
Some(proto2FileDesc))
+      val topLevelDescriptor = descWithExt.descriptor
+
+      val nestedFieldDesc = topLevelDescriptor.findFieldByName("nested")
+      val nestedMsgDescriptor = nestedFieldDesc.getMessageType
+      val nestedExtensions = descWithExt.fullNamesToExtensions
+        .getOrElse(nestedMsgDescriptor.getFullName, Seq.empty)
+
+      assert(
+        nestedExtensions.nonEmpty,
+        "Should find extensions for nested message type 
ExtendableNestedMessage")
+      val nestedExtField = nestedExtensions.find(_.getName == 
"nested_ext_field").get
+      val nestedExtInt = nestedExtensions.find(_.getName == 
"nested_ext_int").get
+
+      val nestedMessage = DynamicMessage
+        .newBuilder(nestedMsgDescriptor)
+        .setField(nestedMsgDescriptor.findFieldByName("nested_name"), 
"nested_name_value")
+        .setField(nestedMsgDescriptor.findFieldByName("nested_id"), 42)
+        .setField(nestedExtField, "ext_field_value")
+        .setField(nestedExtInt, 123)
+        .build()
+
+      val topLevelMessage = DynamicMessage
+        .newBuilder(topLevelDescriptor)
+        .setField(topLevelDescriptor.findFieldByName("top_level_name"), 
"top_name")
+        .setField(nestedFieldDesc, nestedMessage)
+        .build()
+
+      val df = Seq(topLevelMessage.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions
+          .from_protobuf($"value", "MessageWithExtendableNested", 
proto2FileDesc)
+          .as("value_from"))
+
+      // Verify nested extension field values are correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("top_level_name") == "top_name")
+      val nestedRow = row.getAs[Row]("nested")
+      assert(nestedRow.getAs[String]("nested_name") == "nested_name_value")
+      assert(nestedRow.getAs[Int]("nested_id") == 42)
+      assert(nestedRow.getAs[String]("nested_ext_field") == "ext_field_value")
+      assert(nestedRow.getAs[Int]("nested_ext_int") == 123)
+
+      // Verify roundtrip preserves values
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "MessageWithExtendableNested", 
proto2FileDesc)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "MessageWithExtendableNested", 
proto2FileDesc)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: roundtrip - proto2 extension in separate file") {
+    withExtensionSupport {
+      // NB: We manually construct a merged descriptor file because Maven's 
Protobuf plugin
+      // generates a single descriptor file for each .proto file.
+      val extDescriptorPath = 
protobufDescriptorFile("proto2_messages_ext.desc")
+      val fdsExt = FileDescriptorSet.parseFrom(
+        CommonProtobufUtils.readDescriptorFileContent(extDescriptorPath))
+      val fds = FileDescriptorSet.parseFrom(proto2FileDesc)
+      val combinedFds = FileDescriptorSet.newBuilder
+        .addAllFile(fdsExt.getFileList)
+        .addAllFile(fds.getFileList)
+        .build()
+      val combinedFdsBytes = combinedFds.toByteArray
+
+      val descWithExt =
+        ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(combinedFds.toByteArray))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = descWithExt.fullNamesToExtensions
+        .getOrElse(descriptor.getFullName, Seq.empty)
+
+      val extMessageField = extensionFields.find(_.getName == 
"cross_file_extension").get
+
+      val nestedMessage = DynamicMessage
+        .newBuilder(extMessageField.getMessageType)
+        .setField(extMessageField.getMessageType.findFieldByName("foo"), 1)
+        .build()
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "main")
+        .setField(extMessageField, nestedMessage)
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      val fromProtoDF = df.select(
+        functions
+          .from_protobuf($"value", "Proto2ExtensionTest", combinedFdsBytes)
+          .as("value_from"))
+
+      // Verify extension field value is correct
+      val row = fromProtoDF.select($"value_from.*").collect().head
+      assert(row.getAs[String]("name") == "main")
+      val crossFileExtRow = row.getAs[Row]("cross_file_extension")
+      assert(crossFileExtRow.getAs[Int]("foo") == 1)
+
+      val toProtoDF = fromProtoDF.select(
+        functions
+          .to_protobuf($"value_from", "Proto2ExtensionTest", combinedFdsBytes)
+          .as("value_to"))
+      val toFromProtoDF = toProtoDF.select(
+        functions
+          .from_protobuf($"value_to", "Proto2ExtensionTest", combinedFdsBytes)
+          .as("value_to_from"))
+
+      checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
+  }
+
+  test("SPARK-55062: proto2 extension field name collision with regular 
field") {
+    withExtensionSupport {
+      val descriptor =
+        ProtobufUtils
+          .buildDescriptor("Proto2ExtensionCollisionTest", 
Some(proto2FileDesc))
+          .descriptor
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "test")
+        .build()
+
+      val df = Seq(message.toByteArray).toDF("value")
+
+      // Some unwrapping required to get to the root error, as it is thrown 
during execution.
+      val e = intercept[SparkException] {
+        df.select(
+          functions.from_protobuf($"value", "Proto2ExtensionCollisionTest", 
proto2FileDesc))
+          .collect()
+      }
+      checkError(
+        exception = e.getCause
+          .asInstanceOf[AnalysisException]
+          .getCause
+          .asInstanceOf[AnalysisException],
+        condition = "PROTOBUF_FIELD_MISSING",
+        parameters = Map(
+          "field" -> "name",
+          "protobufSchema" -> "top-level record",
+          "matchSize" -> "2",
+          "matches" -> "[name, NAME]"))
+    }
+  }
+
+  test("SPARK-55062: schema evolution - data without extensions read with 
extension schema") {
+    withExtensionSupport {
+      val descriptorWithoutExt =
+        ProtobufUtils.buildDescriptor("Proto2ExtensionTestBase", 
Some(proto2FileDesc)).descriptor
+      val messageWithoutExtensions = DynamicMessage
+        .newBuilder(descriptorWithoutExt)
+        .setField(descriptorWithoutExt.findFieldByName("name"), "base_name")
+        .setField(descriptorWithoutExt.findFieldByName("id"), 42)
+        .build()
+
+      val df = Seq(messageWithoutExtensions.toByteArray).toDF("value")
+
+      // Read using schema WITH extensions
+      val fromProtoDF = df.select(
+        functions.from_protobuf($"value", "Proto2ExtensionTest", 
proto2FileDesc).as("parsed"))
+      val result = fromProtoDF.select($"parsed.*").collect()
+      assert(result.length == 1)
+      val row = result(0)
+
+      // Regular fields have expected values, while extension fields are 
null/empty
+      assert(row.getAs[String]("name") == "base_name")
+      assert(row.getAs[Int]("id") == 42)
+      assert(row.getAs[String]("extension_string") == null)
+      assert(row.isNullAt(row.fieldIndex("extension_int")))
+      assert(row.isNullAt(row.fieldIndex("extension_enum")))
+      assert(row.isNullAt(row.fieldIndex("extension_message")))
+      assert(row.isNullAt(row.fieldIndex("nested_extension_bool")))
+      assert(row.getSeq[Int](row.fieldIndex("extension_repeated_int")).isEmpty)
+    }
+  }
+
+  test("SPARK-55062: Java class fallback drops extensions") {
+    withExtensionSupport {
+      val descWithExt = ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+      val descriptor = descWithExt.descriptor
+      val extensionFields = 
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+      val extStringField = extensionFields.find(_.getName == 
"extension_string").get
+
+      val message = DynamicMessage
+        .newBuilder(descriptor)
+        .setField(descriptor.findFieldByName("name"), "test_name")
+        .setField(descriptor.findFieldByName("id"), 42)
+        .setField(extStringField, "ext_value")
+        .build()
+      val df = Seq(message.toByteArray).toDF("value")
+
+      // Read using Java class, which does not support extensions
+      val fromProtoDF = df.select(
+        functions
+          .from_protobuf($"value", proto2JavaClassNamePrefix + 
"Proto2ExtensionTest")
+          .as("parsed"))
+
+      // Schema should only have regular fields, not extension fields
+      val row = fromProtoDF.select($"parsed.*").collect().head
+      val schema = row.schema
+      assert(row.getAs[String]("name") == "test_name")
+      assert(row.getAs[Int]("id") == 42)
+      assert(!schema.fieldNames.contains("extension_string"))
+      assert(!schema.fieldNames.contains("extension_int"))
+    }
+  }
+
+  test("SPARK-55062: from_protobuf drops extension fields by default") {
+    // We need extensions on when calling buildDescriptor so that we can 
access the extension
+    // fields for test setup and assertions.
+    val descWithExt = withExtensionSupport {
+      ProtobufUtils.buildDescriptor("Proto2ExtensionTest", 
Some(proto2FileDesc))
+    }
+    val descriptor = descWithExt.descriptor
+    val extensionFields = 
descWithExt.fullNamesToExtensions(descriptor.getFullName)
+
+    val extStringField = extensionFields.find(_.getName == 
"extension_string").get
+    val message = DynamicMessage
+      .newBuilder(descriptor)
+      .setField(descriptor.findFieldByName("name"), "test_name")
+      .setField(descriptor.findFieldByName("id"), 42)
+      .setField(extStringField, "ext_value")
+      .build()
+
+    val df = Seq(message.toByteArray).toDF("value")
+
+    val fromProtoDF = df.select(
+      functions
+        .from_protobuf($"value", "Proto2ExtensionTest", proto2FileDesc)
+        .as("value_from"))
+    val row = fromProtoDF.select($"value_from.*").collect().head
+    val schema = row.schema
+    assert(row.getAs[String]("name") == "test_name")
+    assert(row.getAs[Int]("id") == 42)
+    assert(!schema.fieldNames.contains("extension_string"))
+    assert(!schema.fieldNames.contains("extension_int"))
+  }
+
+  test("SPARK-55062: to_protobuf does not recognize extension fields by 
default") {
+    val schema = StructType(
+      StructField(
+        "Proto2ExtensionTest",
+        StructType(
+          StructField("name", StringType) ::
+            StructField("id", IntegerType) ::
+            StructField("extension_string", StringType) ::
+            StructField("extension_int", IntegerType) :: Nil)) :: Nil)
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row("test_name", 42, 
"extension_value", 123)))),
+      schema)
+
+    // We expect an error because, from to_protobuf's perspective, the 
DataFrame contains extra
+    // fields that are not in the message descriptor.
+    val e = intercept[AnalysisException] {
+      val toProtoDF = df
+        .select(
+          functions
+            .to_protobuf($"Proto2ExtensionTest", "Proto2ExtensionTest", 
proto2FileDesc)
+            .as("to_proto"))
+        .collect()
+    }
+    val toType = "\"STRUCT<name: STRING, id: INT, extension_string: STRING, 
extension_int: INT>\""
+    checkError(
+      exception = e,
+      condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+      parameters = Map("protobufType" -> "Proto2ExtensionTest", "toType" -> 
toType))
+  }
+
+  // Extension support is disabled by default.
+  private def withExtensionSupport[T](f: => T): T = {
+    val old = conf.getConf(PROTOBUF_EXTENSIONS_SUPPORT_ENABLED)
+    conf.setConf(PROTOBUF_EXTENSIONS_SUPPORT_ENABLED, true)
+    try {
+      f
+    } finally {
+      conf.setConf(PROTOBUF_EXTENSIONS_SUPPORT_ENABLED, old)
+    }
+  }
+}
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index e46fcb1a1735..b283e4fb8fa9 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -163,8 +163,12 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") {
-    val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"RepeatedMessage")
-    val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"BasicMessage")
+    val repeatedMessageDesc = ProtobufUtils
+      .buildDescriptor("RepeatedMessage", Some(testFileDesc))
+      .descriptor
+    val basicMessageDesc = ProtobufUtils
+      .buildDescriptor("BasicMessage", Some(testFileDesc))
+      .descriptor
 
     val basicMessage = DynamicMessage
       .newBuilder(basicMessageDesc)
@@ -200,8 +204,12 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") {
-    val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"RepeatedMessage")
-    val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"BasicMessage")
+    val repeatedMessageDesc = ProtobufUtils
+      .buildDescriptor("RepeatedMessage", Some(testFileDesc))
+      .descriptor
+    val basicMessageDesc = ProtobufUtils
+      .buildDescriptor("BasicMessage", Some(testFileDesc))
+      .descriptor
 
     val basicMessage1 = DynamicMessage
       .newBuilder(basicMessageDesc)
@@ -251,7 +259,9 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Map") {
-    val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"SimpleMessageMap")
+    val messageMapDesc = ProtobufUtils
+      .buildDescriptor("SimpleMessageMap", Some(testFileDesc))
+      .descriptor
 
     val mapStr1 = DynamicMessage
       .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
@@ -345,8 +355,12 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Enum") {
-    val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"SimpleMessageEnum")
-    val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"BasicEnumMessage")
+    val messageEnumDesc = ProtobufUtils
+      .buildDescriptor("SimpleMessageEnum", Some(testFileDesc))
+      .descriptor
+    val basicEnumDesc = ProtobufUtils
+      .buildDescriptor("BasicEnumMessage", Some(testFileDesc))
+      .descriptor
 
     val dynamicMessage = DynamicMessage
       .newBuilder(messageEnumDesc)
@@ -389,9 +403,15 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("round trip in from_protobuf and to_protobuf - Multiple Message") {
-    val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"MultipleExample")
-    val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"IncludedExample")
-    val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"OtherExample")
+    val messageMultiDesc = ProtobufUtils
+      .buildDescriptor("MultipleExample", Some(testFileDesc))
+      .descriptor
+    val messageIncludeDesc = ProtobufUtils
+      .buildDescriptor("IncludedExample", Some(testFileDesc))
+      .descriptor
+    val messageOtherDesc = ProtobufUtils
+      .buildDescriptor("OtherExample", Some(testFileDesc))
+      .descriptor
 
     val otherMessage = DynamicMessage
       .newBuilder(messageOtherDesc)
@@ -470,8 +490,8 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
     val descBytes = 
CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
 
-    val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
-    val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
+    val oldProducer = ProtobufUtils.buildDescriptor("oldProducer", 
Some(descBytes)).descriptor
+    val newConsumer = ProtobufUtils.buildDescriptor("newConsumer", 
Some(descBytes)).descriptor
 
     val oldProducerMessage = DynamicMessage
       .newBuilder(oldProducer)
@@ -512,8 +532,8 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
     val descBytes = 
CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
 
-    val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
-    val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")
+    val newProducer = ProtobufUtils.buildDescriptor("newProducer", 
Some(descBytes)).descriptor
+    val oldConsumer = ProtobufUtils.buildDescriptor("oldConsumer", 
Some(descBytes)).descriptor
 
     val newProducerMessage = DynamicMessage
       .newBuilder(newProducer)
@@ -560,7 +580,9 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
     val binary = toProtobuf.first().get(0).asInstanceOf[Array[Byte]]
 
-    val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc, 
"requiredMsg")
+    val messageDescriptor = ProtobufUtils
+      .buildDescriptor("requiredMsg", Some(testFileDesc))
+      .descriptor
     val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary)
 
     assert(actualMessage.getField(messageDescriptor.findFieldByName("key"))
@@ -582,7 +604,9 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("from_protobuf filter to_protobuf") {
-    val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"BasicMessage")
+    val basicMessageDesc = ProtobufUtils
+      .buildDescriptor("BasicMessage", Some(testFileDesc))
+      .descriptor
 
     val basicMessage = DynamicMessage
       .newBuilder(basicMessageDesc)
@@ -714,7 +738,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("Verify OneOf field between from_protobuf -> to_protobuf and struct -> 
from_protobuf") {
-    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent")
+    val descriptor = ProtobufUtils.buildDescriptor("OneOfEvent", 
Some(testFileDesc)).descriptor
     val oneOfEvent = OneOfEvent.newBuilder()
       .setKey("key")
       .setCol1(123)
@@ -804,7 +828,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   }
 
   test("Verify recursion field with complex schema with 
recursive.fields.max.depth") {
-    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Employee")
+    val descriptor = ProtobufUtils.buildDescriptor("Employee", 
Some(testFileDesc)).descriptor
 
     val manager = 
Employee.newBuilder().setFirstName("firstName").setLastName("lastName").build()
     val em2 = EM2.newBuilder().setTeamsize(100).setEm2Manager(manager).build()
@@ -846,7 +870,9 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
   test("Verify OneOf field with recursive fields between from_protobuf -> 
to_protobuf." +
     "and struct -> from_protobuf") {
-    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, 
"OneOfEventWithRecursion")
+    val descriptor = ProtobufUtils
+      .buildDescriptor("OneOfEventWithRecursion", Some(testFileDesc))
+      .descriptor
 
     val nestedTwo = OneOfEventWithRecursion.newBuilder()
       .setKey("keyNested2").setValue("valueNested2").build()
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index f3bd49e1b24a..71bcb6b51a56 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -47,7 +47,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
 
   test("Test basic conversion") {
     withFieldMatchType { fieldMatch =>
-      val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, 
"SerdeBasicMessage")
+      val protoFile =
+        ProtobufUtils.buildDescriptor("SerdeBasicMessage", 
Some(testFileDesc)).descriptor
 
       val dynamicMessageFoo = DynamicMessage
         .newBuilder(protoFile.getFile.findMessageTypeByName("Foo"))
@@ -71,7 +72,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
     // This test verifies that optional fields can be missing from input 
Catalyst schema
     // while serializing rows to protobuf.
 
-    val desc = ProtobufUtils.buildDescriptor(proto2Desc, 
"FoobarWithRequiredFieldBar")
+    val desc =
+      ProtobufUtils.buildDescriptor("FoobarWithRequiredFieldBar", 
Some(proto2Desc)).descriptor
 
     // Confirm desc contains optional field 'foo' and required field bar.
     assert(desc.getFields.size() == 2)
@@ -90,7 +92,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
   }
 
   test("Fail to convert with field type mismatch") {
-    val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, 
"MissMatchTypeInRoot")
+    val protoFile =
+      ProtobufUtils.buildDescriptor("MissMatchTypeInRoot", 
Some(testFileDesc)).descriptor
     withFieldMatchType { fieldMatch =>
       assertFailedConversionMessage(
         protoFile,
@@ -113,7 +116,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
   }
 
   test("Fail to convert with missing nested Protobuf fields for serializer") {
-    val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, 
"FieldMissingInProto")
+    val protoFile =
+      ProtobufUtils.buildDescriptor("FieldMissingInProto", 
Some(testFileDesc)).descriptor
 
     val nonnullCatalyst = new StructType()
       .add("foo", new StructType().add("bar", IntegerType, nullable = false))
@@ -142,7 +146,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
   test("Fail to convert with deeply nested field type mismatch") {
     val protoFile = ProtobufUtils.buildDescriptorFromJavaClass(
       s"${javaClassNamePrefix}MissMatchTypeInDeepNested"
-    )
+    ).descriptor
     val catalyst = new StructType().add("top", CATALYST_STRUCT)
 
     withFieldMatchType { fieldMatch =>
@@ -169,12 +173,13 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
   }
 
   test("Fail to convert with missing Catalyst fields") {
-    val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, 
"FieldMissingInProto")
+    val protoFile =
+      ProtobufUtils.buildDescriptor("FieldMissingInProto", 
Some(testFileDesc)).descriptor
 
     val foobarSQLType = structFromDDL("struct<foo string>") // "bar" is 
missing.
 
     assertFailedConversionMessage(
-      ProtobufUtils.buildDescriptor(proto2Desc, "FoobarWithRequiredFieldBar"),
+      ProtobufUtils.buildDescriptor("FoobarWithRequiredFieldBar", 
Some(proto2Desc)).descriptor,
       Serializer,
       BY_NAME,
       catalystSchema = foobarSQLType,
@@ -189,14 +194,17 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
     withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
 
     val protoNestedFile = ProtobufUtils
-      .buildDescriptor(proto2Desc, "NestedFoobarWithRequiredFieldBar")
+      .buildDescriptor("NestedFoobarWithRequiredFieldBar", Some(proto2Desc))
+      .descriptor
 
     val nestedFoobarSQLType = structFromDDL(
       "struct<nested_foobar: struct<foo string>>" // "bar" field is missing.
     )
     // serializing with extra fails if required field is missing in inner 
struct
     assertFailedConversionMessage(
-      ProtobufUtils.buildDescriptor(proto2Desc, 
"NestedFoobarWithRequiredFieldBar"),
+      ProtobufUtils
+        .buildDescriptor("NestedFoobarWithRequiredFieldBar", Some(proto2Desc))
+        .descriptor,
       Serializer,
       BY_NAME,
       catalystSchema = nestedFoobarSQLType,
@@ -216,9 +224,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
 
     val e1 = intercept[AnalysisException] {
       ProtobufUtils.buildDescriptor(
-        CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
-        "SerdeBasicMessage"
-      )
+        "SerdeBasicMessage",
+        Some(CommonProtobufUtils.readDescriptorFileContent(fileDescFile)))
     }
 
     checkError(
@@ -234,9 +241,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
 
 
     val e2 = intercept[AnalysisException] {
-      ProtobufUtils.buildDescriptor(
-        basicMessageDescWithoutImports,
-        "BasicMessage")
+      ProtobufUtils.buildDescriptor("BasicMessage", 
Some(basicMessageDescWithoutImports))
     }
 
     checkError(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 495ece401949..b2a2b7027394 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -6896,6 +6896,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(Utils.isTesting)
 
+  val PROTOBUF_EXTENSIONS_SUPPORT_ENABLED =
+    buildConf("spark.sql.function.protobufExtensions.enabled")
+      .doc("When true, the from_protobuf and to_protobuf operators will 
support proto2 " +
+        "extensions when a binary file descriptor set is provided. This 
property will have no " +
+        "effect for the overloads taking a Java class name instead of a file 
descriptor set.")
+      .version("4.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index 23c266e82891..dff09c736a83 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -203,7 +203,7 @@ object OffsetSeqMetadata extends Logging {
     STATE_STORE_ROCKSDB_FORMAT_VERSION, 
STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION,
     STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
     PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, 
STREAMING_STATE_STORE_ENCODING_FORMAT,
-    STATE_STORE_ROW_CHECKSUM_ENABLED
+    STATE_STORE_ROW_CHECKSUM_ENABLED, PROTOBUF_EXTENSIONS_SUPPORT_ENABLED
   )
 
   /**
@@ -251,7 +251,8 @@ object OffsetSeqMetadata extends Logging {
     PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true",
     STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow",
     STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false",
-    STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1"
+    STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1",
+    PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key -> "false"
   )
 
   def readValue[T](metadataLog: OffsetSeqMetadataBase, confKey: 
ConfigEntry[T]): String = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProtobufExtensionsSupportOffsetLogSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProtobufExtensionsSupportOffsetLogSuite.scala
new file mode 100644
index 000000000000..0b9914e1dff0
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProtobufExtensionsSupportOffsetLogSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetSeqBase, 
OffsetSeqLog, OffsetSeqMetadata}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Tests for the spark.sql.function.protobufExtensions.enabled flag and its 
persistence in the
+ * offset log.
+ */
+class ProtobufExtensionsSupportOffsetLogSuite extends SharedSparkSession {
+
+  test("protobuf extensions support disabled by default for new queries") {
+    val offsetSeqMetadata =
+      OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, spark.conf)
+    assert(
+      
offsetSeqMetadata.conf.get(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key) ===
+        Some(false.toString))
+  }
+
+  test("protobuf extensions support enabled when session sets it to true") {
+    val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+    withSQLConf(protobufExtConf -> true.toString) {
+      val offsetSeqMetadata =
+        OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, 
spark.conf)
+      assert(offsetSeqMetadata.conf.get(protobufExtConf) === 
Some(true.toString))
+    }
+  }
+
+  test(
+    "protobuf extensions support uses default false for old checkpoint when 
enabled in session") {
+    val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+    withSQLConf(protobufExtConf -> true.toString) {
+      val existingChkpt = "offset-log-version-2.1.0"
+      val (_, offsetSeq) = readFromResource(existingChkpt)
+      val offsetSeqMetadata = offsetSeq.metadataOpt.get
+      // Not present in existing checkpoint
+      assert(offsetSeqMetadata.conf.get(protobufExtConf) === None)
+
+      val clonedSqlConf = spark.sessionState.conf.clone()
+      OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+      
assert(!clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+    }
+  }
+
+  test(
+    "protobuf extensions support uses default false for old checkpoint when 
unset in session") {
+    val existingChkpt = "offset-log-version-2.1.0"
+    val (_, offsetSeq) = readFromResource(existingChkpt)
+    val offsetSeqMetadata = offsetSeq.metadataOpt.get
+    val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+    // Not present in existing checkpoint
+    assert(offsetSeqMetadata.conf.get(protobufExtConf) === None)
+
+    val clonedSqlConf = spark.sessionState.conf.clone()
+    OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+    assert(!clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+  }
+
+  test("protobuf extensions support in existing checkpoint takes precedence 
over session value") {
+    val protobufExtConf = SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key
+
+    // Enabled when checkpoint = enabled and session = disabled
+    withSQLConf(protobufExtConf -> false.toString) {
+      val offsetSeqMetadata = OffsetSeqMetadata(
+        batchWatermarkMs = 0,
+        batchTimestampMs = 0,
+        Map(protobufExtConf -> true.toString))
+
+      val clonedSqlConf = spark.sessionState.conf.clone()
+      OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+      
assert(clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+    }
+
+    // Disabled when checkpoint = disabled and session = enabled
+    withSQLConf(protobufExtConf -> true.toString) {
+      val offsetSeqMetadata = OffsetSeqMetadata(
+        batchWatermarkMs = 0,
+        batchTimestampMs = 0,
+        Map(protobufExtConf -> false.toString))
+
+      val clonedSqlConf = spark.sessionState.conf.clone()
+      OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf)
+      
assert(!clonedSqlConf.getConf(SQLConf.PROTOBUF_EXTENSIONS_SUPPORT_ENABLED))
+    }
+  }
+
+  private def readFromResource(dir: String): (Long, OffsetSeqBase) = {
+    val input = getClass.getResource(s"/structured-streaming/$dir")
+    val log = new OffsetSeqLog(spark, input.toString)
+    log.getLatest().get
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to