This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fd9e5760bae [SPARK-40657] Add support for Java classes in Protobuf 
functions
fd9e5760bae is described below

commit fd9e5760bae847f47c9c108f0e58814748e0d9b1
Author: Raghu Angadi <raghu.ang...@databricks.com>
AuthorDate: Fri Oct 21 15:46:50 2022 +0900

    [SPARK-40657] Add support for Java classes in Protobuf functions
    
    ### What changes were proposed in this pull request?
    
    Adds support for compiled Java classes to Protobuf functions. This is 
tested with Protobuf v3 classes. V2 vs V3 issues will be handled in a separate 
PR. The main changes in this PR:
    
     - Changes to top level API:
        - Adds new overloads that take only the Java class name (no descriptor file).
        - Changes the order of arguments for existing API with descriptor files 
(`messageName` and `descFilePath` are swapped).
     - Adds `ProtobufUtils` methods to build a descriptor from a Java class name.
     - Many unit tests are updated to check both variants: (1) with a descriptor 
file and (2) with a Java class name.
     - Maven build updates to generate Java classes to use in tests.
     - Miscellaneous changes:
        - Adds `proto` to package name in `proto` files used in tests.
        - Adds a few TODO comments about future improvements.
    
    ### Why are the changes needed?
    Compiled Java classes are a common way for users to provide Protobuf 
definitions.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    This changes the interface, but only for a new feature that is still in active development.
    
    ### How was this patch tested?
     - Unit tests
    
    Closes #38286 from rangadi/protobuf-java.
    
    Authored-by: Raghu Angadi <raghu.ang...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 connector/protobuf/pom.xml                         |  23 +-
 .../sql/protobuf/CatalystDataToProtobuf.scala      |  10 +-
 .../sql/protobuf/ProtobufDataToCatalyst.scala      |  34 ++-
 .../org/apache/spark/sql/protobuf/functions.scala  |  58 +++-
 .../spark/sql/protobuf/utils/ProtobufUtils.scala   |  65 ++++-
 .../sql/protobuf/utils/SchemaConverters.scala      |   4 +
 .../test/resources/protobuf/catalyst_types.proto   |   4 +-
 .../test/resources/protobuf/functions_suite.proto  |   4 +-
 .../src/test/resources/protobuf/serde_suite.proto  |   6 +-
 .../ProtobufCatalystDataConversionSuite.scala      |  97 +++++--
 .../sql/protobuf/ProtobufFunctionsSuite.scala      | 318 +++++++++++++--------
 .../spark/sql/protobuf/ProtobufSerdeSuite.scala    |   9 +-
 project/SparkBuild.scala                           |   6 +-
 python/pyspark/sql/protobuf/functions.py           |  22 +-
 14 files changed, 437 insertions(+), 223 deletions(-)

diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml
index 0515f128b8d..b934c7f831a 100644
--- a/connector/protobuf/pom.xml
+++ b/connector/protobuf/pom.xml
@@ -83,7 +83,6 @@
       <version>${protobuf.version}</version>
       <scope>compile</scope>
     </dependency>
-
   </dependencies>
   <build>
     
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -110,6 +109,28 @@
           </relocations>
         </configuration>
       </plugin>
+      <plugin>
+        <groupId>com.github.os72</groupId>
+        <artifactId>protoc-jar-maven-plugin</artifactId>
+        <version>3.11.4</version>
+        <!-- Generates Java classes for tests. TODO(Raghu): Generate 
descriptor files too. -->
+        <executions>
+          <execution>
+            <phase>generate-test-sources</phase>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <configuration>
+              
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
+              <protocVersion>${protobuf.version}</protocVersion>
+              <inputDirectories>
+                <include>src/test/resources/protobuf</include>
+              </inputDirectories>
+              <addSources>test</addSources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
     </plugins>
   </build>
 </project>
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
index 145100268c2..b9f7907ea8c 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -25,17 +25,17 @@ import org.apache.spark.sql.types.{BinaryType, DataType}
 
 private[protobuf] case class CatalystDataToProtobuf(
     child: Expression,
-    descFilePath: String,
-    messageName: String)
+    messageName: String,
+    descFilePath: Option[String] = None)
     extends UnaryExpression {
 
   override def dataType: DataType = BinaryType
 
-  @transient private lazy val protoType =
-    ProtobufUtils.buildDescriptor(descFilePath, messageName)
+  @transient private lazy val protoDescriptor =
+    ProtobufUtils.buildDescriptor(messageName, descFilePathOpt = descFilePath)
 
   @transient private lazy val serializer =
-    new ProtobufSerializer(child.dataType, protoType, child.nullable)
+    new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable)
 
   override def nullSafeEval(input: Any): Any = {
     val dynamicMessage = 
serializer.serialize(input).asInstanceOf[DynamicMessage]
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
index f08f8767997..cad2442f10c 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.types.{AbstractDataType, 
BinaryType, DataType, Struc
 
 private[protobuf] case class ProtobufDataToCatalyst(
     child: Expression,
-    descFilePath: String,
     messageName: String,
-    options: Map[String, String])
+    descFilePath: Option[String] = None,
+    options: Map[String, String] = Map.empty)
     extends UnaryExpression
     with ExpectsInputTypes {
 
@@ -55,10 +55,14 @@ private[protobuf] case class ProtobufDataToCatalyst(
   private lazy val protobufOptions = ProtobufOptions(options)
 
   @transient private lazy val messageDescriptor =
-    ProtobufUtils.buildDescriptor(descFilePath, messageName)
+    ProtobufUtils.buildDescriptor(messageName, descFilePath)
+    // TODO: Avoid carrying the file name. Read the contents of descriptor 
file only once
+    //       at the start. Rest of the runs should reuse the buffer. 
Otherwise, it could
+    //       cause inconsistencies if the file contents are changed the user 
after a few days.
+    //       Same for the write side in [[CatalystDataToProtobuf]].
 
   @transient private lazy val fieldsNumbers =
-    messageDescriptor.getFields.asScala.map(f => f.getNumber)
+    messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet
 
   @transient private lazy val deserializer = new 
ProtobufDeserializer(messageDescriptor, dataType)
 
@@ -108,18 +112,18 @@ private[protobuf] case class ProtobufDataToCatalyst(
     val binary = input.asInstanceOf[Array[Byte]]
     try {
       result = DynamicMessage.parseFrom(messageDescriptor, binary)
-      val unknownFields = result.getUnknownFields
-      if (!unknownFields.asMap().isEmpty) {
-        unknownFields.asMap().keySet().asScala.map { number =>
-          {
-            if (fieldsNumbers.contains(number)) {
-              return handleException(
-                new Throwable(s"Type mismatch encountered for field:" +
-                  s" ${messageDescriptor.getFields.get(number)}"))
-            }
-          }
-        }
+      // If the Java class is available, it is likely more efficient to parse 
with it than using
+      // DynamicMessage. Can consider it in the future if parsing overhead is 
noticeable.
+
+      
result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_))
 match {
+        case Some(number) =>
+          // Unknown fields contain a field with same number as a known field. 
Must be due to
+          // mismatch of schema between writer and reader here.
+          throw new IllegalArgumentException(s"Type mismatch encountered for 
field:" +
+              s" ${messageDescriptor.getFields.get(number)}")
+        case None =>
       }
+
       val deserialized = deserializer.deserialize(result)
       assert(
         deserialized.isDefined,
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
index 283d1ca8c41..af30de40dad 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -33,20 +33,21 @@ object functions {
    *
    * @param data
    *   the binary column.
-   * @param descFilePath
-   *   the protobuf descriptor in Message GeneratedMessageV3 format.
    * @param messageName
    *   the protobuf message name to look for in descriptorFile.
+   * @param descFilePath
+   *   the protobuf descriptor in Message GeneratedMessageV3 format.
    * @since 3.4.0
    */
   @Experimental
   def from_protobuf(
       data: Column,
-      descFilePath: String,
       messageName: String,
+      descFilePath: String,
       options: java.util.Map[String, String]): Column = {
     new Column(
-      ProtobufDataToCatalyst(data.expr, descFilePath, messageName, 
options.asScala.toMap))
+      ProtobufDataToCatalyst(data.expr, messageName, Some(descFilePath), 
options.asScala.toMap)
+    )
   }
 
   /**
@@ -57,15 +58,34 @@ object functions {
    *
    * @param data
    *   the binary column.
-   * @param descFilePath
-   *   the protobuf descriptor in Message GeneratedMessageV3 format.
    * @param messageName
    *   the protobuf MessageName to look for in descriptorFile.
+   * @param descFilePath
+   *   the protobuf descriptor in Message GeneratedMessageV3 format.
    * @since 3.4.0
    */
   @Experimental
-  def from_protobuf(data: Column, descFilePath: String, messageName: String): 
Column = {
-    new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, 
Map.empty))
+  def from_protobuf(data: Column, messageName: String, descFilePath: String): 
Column = {
+    new Column(ProtobufDataToCatalyst(data.expr, messageName, descFilePath = 
Some(descFilePath)))
+    // TODO: Add an option for user to provide descriptor file content as a 
buffer. This
+    //       gives flexibility in how the content is fetched.
+  }
+
+  /**
+   * Converts a binary column of Protobuf format into its corresponding 
catalyst value. The
+   * specified schema must match actual schema of the read data, otherwise the 
behavior is
+   * undefined: it may fail or return arbitrary result. To deserialize the 
data with a compatible
+   * and evolved schema, the expected Protobuf schema can be set via the 
option protoSchema.
+   *
+   * @param data
+   *   the binary column.
+   * @param messageClassName
+   *   The Protobuf class name. E.g. 
<code>org.spark.examples.protobuf.ExampleEvent</code>.
+   * @since 3.4.0
+   */
+  @Experimental
+  def from_protobuf(data: Column, messageClassName: String): Column = {
+    new Column(ProtobufDataToCatalyst(data.expr, messageClassName))
   }
 
   /**
@@ -73,14 +93,28 @@ object functions {
    *
    * @param data
    *   the data column.
-   * @param descFilePath
-   *   the protobuf descriptor in Message GeneratedMessageV3 format.
    * @param messageName
    *   the protobuf MessageName to look for in descriptorFile.
+   * @param descFilePath
+   *   the protobuf descriptor in Message GeneratedMessageV3 format.
+   * @since 3.4.0
+   */
+  @Experimental
+  def to_protobuf(data: Column, messageName: String, descFilePath: String): 
Column = {
+    new Column(CatalystDataToProtobuf(data.expr, messageName, 
Some(descFilePath)))
+  }
+
+  /**
+   * Converts a column into binary of protobuf format.
+   *
+   * @param data
+   *   the data column.
+   * @param messageClassName
+   *   The Protobuf class name. E.g. 
<code>org.spark.examples.protobuf.ExampleEvent</code>.
    * @since 3.4.0
    */
   @Experimental
-  def to_protobuf(data: Column, descFilePath: String, messageName: String): 
Column = {
-    new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName))
+  def to_protobuf(data: Column, messageClassName: String): Column = {
+    new Column(CatalystDataToProtobuf(data.expr, messageClassName))
   }
 }
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index 5ad043142a2..fa2ec9b7cd4 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -22,13 +22,14 @@ import java.util.Locale
 
 import scala.collection.JavaConverters._
 
-import com.google.protobuf.{DescriptorProtos, Descriptors, 
InvalidProtocolBufferException}
+import com.google.protobuf.{DescriptorProtos, Descriptors, 
InvalidProtocolBufferException, Message}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.internal.SQLConf
 import 
org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 private[sql] object ProtobufUtils extends Logging {
 
@@ -132,23 +133,63 @@ private[sql] object ProtobufUtils extends Logging {
     }
   }
 
-  def buildDescriptor(descFilePath: String, messageName: String): Descriptor = 
{
-    val fileDescriptor: Descriptors.FileDescriptor = 
parseFileDescriptor(descFilePath)
-    var result: Descriptors.Descriptor = null;
+  /**
+   * Builds Protobuf message descriptor either from the Java class or from 
serialized descriptor
+   * read from the file.
+   * @param messageName
+   *  Protobuf message name or Java class name.
+   * @param descFilePathOpt
+   *  When the file name set, the descriptor and it's dependencies are read 
from the file. Other
+   *  the `messageName` is treated as Java class name.
+   * @return
+   */
+  def buildDescriptor(messageName: String, descFilePathOpt: Option[String]): 
Descriptor = {
+    descFilePathOpt match {
+      case Some(filePath) => buildDescriptor(descFilePath = filePath, 
messageName)
+      case None => buildDescriptorFromJavaClass(messageName)
+    }
+  }
 
-    for (descriptor <- fileDescriptor.getMessageTypes.asScala) {
-      if (descriptor.getName().equals(messageName)) {
-        result = descriptor
-      }
+  /**
+   *  Loads the given protobuf class and returns Protobuf descriptor for it.
+   */
+  def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = {
+    val protobufClass = try {
+      Utils.classForName(protobufClassName)
+    } catch {
+      case _: ClassNotFoundException =>
+        val hasDots = protobufClassName.contains(".")
+        throw new IllegalArgumentException(
+          s"Could not load Protobuf class with name '$protobufClassName'" +
+          (if (hasDots) "" else ". Ensure the class name includes package 
prefix.")
+        )
+    }
+
+    if (!classOf[Message].isAssignableFrom(protobufClass)) {
+      throw new IllegalArgumentException(s"$protobufClassName is not a 
Protobuf message type")
+      // TODO: Need to support V2. This might work with V2 classes too.
+    }
+
+    // Extract the descriptor from Protobuf message.
+    protobufClass
+      .getDeclaredMethod("getDescriptor")
+      .invoke(null)
+      .asInstanceOf[Descriptor]
+  }
+
+  def buildDescriptor(descFilePath: String, messageName: String): Descriptor = 
{
+    val descriptor = 
parseFileDescriptor(descFilePath).getMessageTypes.asScala.find { desc =>
+      desc.getName == messageName || desc.getFullName == messageName
     }
 
-    if (null == result) {
-      throw new RuntimeException("Unable to locate Message '" + messageName + 
"' in Descriptor");
+    descriptor match {
+      case Some(d) => d
+      case None =>
+        throw new RuntimeException(s"Unable to locate Message '$messageName' 
in Descriptor")
     }
-    result
   }
 
-  def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = {
+  private def parseFileDescriptor(descFilePath: String): 
Descriptors.FileDescriptor = {
     var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
     try {
       val dscFile = new BufferedInputStream(new FileInputStream(descFilePath))
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index e385b816abe..4fca06fb5d8 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -66,6 +66,10 @@ object SchemaConverters {
         Some(DayTimeIntervalType.defaultConcreteType)
       case MESSAGE if fd.getMessageType.getName == "Timestamp" =>
         Some(TimestampType)
+        // FIXME: Is the above accurate? Users can have protos named 
"Timestamp" but are not
+        //        expected to be TimestampType in Spark. How about verifying 
fields?
+        //        Same for "Duration". Only the Timestamp & Duration protos 
defined in
+        //        google.protobuf package should default to corresponding 
Catalylist types.
       case MESSAGE if fd.isRepeated && 
fd.getMessageType.getOptions.hasMapEntry =>
         var keyType: DataType = NullType
         var valueType: DataType = NullType
diff --git 
a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto 
b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
index 54e6bc18df1..1deb193438c 100644
--- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
+++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
@@ -19,9 +19,11 @@
 
 syntax = "proto3";
 
-package org.apache.spark.sql.protobuf;
+package org.apache.spark.sql.protobuf.protos;
 option java_outer_classname = "CatalystTypes";
 
+// TODO: import one or more protobuf files.
+
 message BooleanMsg {
   bool bool_type = 1;
 }
diff --git 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
index f38c041b799..60f8c262141 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -20,7 +20,7 @@
 
 syntax = "proto3";
 
-package org.apache.spark.sql.protobuf;
+package org.apache.spark.sql.protobuf.protos;
 
 option java_outer_classname = "SimpleMessageProtos";
 
@@ -119,7 +119,7 @@ message SimpleMessageEnum {
   string key = 1;
   string value = 2;
   enum NestedEnum {
-    ESTED_NOTHING = 0;
+    ESTED_NOTHING = 0; // TODO: Fix the name.
     NESTED_FIRST = 1;
     NESTED_SECOND = 2;
   }
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto 
b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
index 1e3065259aa..a7459213a87 100644
--- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
@@ -20,11 +20,11 @@
 
 syntax = "proto3";
 
-package org.apache.spark.sql.protobuf;
-option java_outer_classname = "SimpleMessageProtos";
+package org.apache.spark.sql.protobuf.protos;
+option java_outer_classname = "SerdeSuiteProtos";
 
 /* Clean Message*/
-message BasicMessage {
+message SerdeBasicMessage {
   Foo foo = 1;
 }
 
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index b730ebb4fea..19774a2ad07 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, 
NoopFilters, OrderedFilters, StructFilters}
 import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, 
GenericInternalRow, Literal}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, 
GenericArrayData, MapData}
+import org.apache.spark.sql.protobuf.protos.CatalystTypes.BytesMsg
 import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
 import org.apache.spark.sql.sources.{EqualTo, Not}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -35,18 +36,32 @@ class ProtobufCatalystDataConversionSuite
     with SharedSparkSession
     with ExpressionEvalHelper {
 
-  private def checkResult(
+  private val testFileDesc = 
testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+  private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.CatalystTypes$"
+
+  private def checkResultWithEval(
       data: Literal,
       descFilePath: String,
       messageName: String,
       expected: Any): Unit = {
-    checkEvaluation(
-      ProtobufDataToCatalyst(
-        CatalystDataToProtobuf(data, descFilePath, messageName),
-        descFilePath,
-        messageName,
-        Map.empty),
-      prepareExpectedResult(expected))
+
+    withClue("(Eval check with Java class name)") {
+      val className = s"$javaClassNamePrefix$messageName"
+      checkEvaluation(
+        ProtobufDataToCatalyst(
+          CatalystDataToProtobuf(data, className),
+          className,
+          descFilePath = None),
+        prepareExpectedResult(expected))
+    }
+    withClue("(Eval check with descriptor file)") {
+      checkEvaluation(
+        ProtobufDataToCatalyst(
+          CatalystDataToProtobuf(data, messageName, Some(descFilePath)),
+          messageName,
+          descFilePath = Some(descFilePath)),
+        prepareExpectedResult(expected))
+    }
   }
 
   protected def checkUnsupportedRead(
@@ -55,10 +70,11 @@ class ProtobufCatalystDataConversionSuite
       actualSchema: String,
       badSchema: String): Unit = {
 
-    val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema)
+    val binary = CatalystDataToProtobuf(data, actualSchema, Some(descFilePath))
 
     intercept[Exception] {
-      ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> 
"FAILFAST")).eval()
+      ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" 
-> "FAILFAST"))
+        .eval()
     }
 
     val expected = {
@@ -73,7 +89,7 @@ class ProtobufCatalystDataConversionSuite
     }
 
     checkEvaluation(
-      ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> 
"PERMISSIVE")),
+      ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" 
-> "PERMISSIVE")),
       expected)
   }
 
@@ -99,26 +115,32 @@ class ProtobufCatalystDataConversionSuite
     StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil),
     StructType(StructField("string_type", StringType, nullable = true) :: Nil))
 
-  private val catalystTypesToProtoMessages: Map[DataType, String] = Map(
-    IntegerType -> "IntegerMsg",
-    DoubleType -> "DoubleMsg",
-    FloatType -> "FloatMsg",
-    BinaryType -> "BytesMsg",
-    StringType -> "StringMsg")
+  private val catalystTypesToProtoMessages: Map[DataType, (String, Any)] = Map(
+    IntegerType -> ("IntegerMsg", 0),
+    DoubleType -> ("DoubleMsg", 0.0d),
+    FloatType -> ("FloatMsg", 0.0f),
+    BinaryType -> ("BytesMsg", ByteString.empty().toByteArray),
+    StringType -> ("StringMsg", ""))
 
   testingTypes.foreach { dt =>
     val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1)
-    val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", 
"/")
     test(s"single $dt with seed $seed") {
+
+      val (messageName, defaultValue) = 
catalystTypesToProtoMessages(dt.fields(0).dataType)
+
       val rand = new scala.util.Random(seed)
-      val data = RandomDataGenerator.forType(dt, rand = rand).get.apply()
+      val generator = RandomDataGenerator.forType(dt, rand = rand).get
+      var data = generator()
+      while (data.asInstanceOf[Row].get(0) == defaultValue) // Do not use 
default values, since
+        data = generator()                                  // from_protobuf() 
returns null in v3.
+
       val converter = CatalystTypeConverters.createToCatalystConverter(dt)
       val input = Literal.create(converter(data), dt)
 
-      checkResult(
+      checkResultWithEval(
         input,
-        filePath,
-        catalystTypesToProtoMessages(dt.fields(0).dataType),
+        testFileDesc,
+        messageName,
         input.eval())
     }
   }
@@ -137,6 +159,15 @@ class ProtobufCatalystDataConversionSuite
 
     val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray)
     val deserialized = deserializer.deserialize(dynMsg)
+
+    // Verify Java class deserializer matches with descriptor based serializer.
+    val javaDescriptor = ProtobufUtils
+      .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName")
+    assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType)
+    val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, 
filters)
+      .deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray))
+    assert(deserialized == javaDeserialized)
+
     expected match {
       case None => assert(deserialized.isEmpty)
       case Some(d) =>
@@ -145,7 +176,6 @@ class ProtobufCatalystDataConversionSuite
   }
 
   test("Handle unsupported input of message type") {
-    val testFileDesc = 
testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
     val actualSchema = StructType(
       Seq(
         StructField("col_0", StringType, nullable = false),
@@ -165,7 +195,6 @@ class ProtobufCatalystDataConversionSuite
 
   test("filter push-down to Protobuf deserializer") {
 
-    val testFileDesc = 
testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
     val sqlSchema = new StructType()
       .add("name", "string")
       .add("age", "int")
@@ -196,17 +225,23 @@ class ProtobufCatalystDataConversionSuite
 
   test("ProtobufDeserializer with binary type") {
 
-    val testFileDesc = 
testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
     val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53))
 
-    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
-
-    val dynamicMessage = DynamicMessage
-      .newBuilder(descriptor)
-      .setField(descriptor.findFieldByName("bytes_type"), 
ByteString.copyFrom(bb))
+    val bytesProto = BytesMsg
+      .newBuilder()
+      .setBytesType(ByteString.copyFrom(bb))
       .build()
 
     val expected = InternalRow(Array[Byte](97, 48, 53))
-    checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, 
Some(expected))
+    checkDeserialization(testFileDesc, "BytesMsg", bytesProto, Some(expected))
+  }
+
+  test("Full names for message using descriptor file") {
+    val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
+    assert(withShortName.findFieldByName("bytes_type") != null)
+
+    val withFullName = ProtobufUtils.buildDescriptor(
+      testFileDesc, "org.apache.spark.sql.protobuf.BytesMsg")
+    assert(withFullName.findFieldByName("bytes_type") != null)
   }
 }
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 4e9bc1c1c28..72280fb0d9e 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -23,8 +23,10 @@ import scala.collection.JavaConverters._
 
 import com.google.protobuf.{ByteString, DynamicMessage}
 
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{Column, QueryTest, Row}
 import org.apache.spark.sql.functions.{lit, struct}
+import 
org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated
+import 
org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import 
org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
 import org.apache.spark.sql.test.SharedSparkSession
@@ -35,6 +37,39 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
   import testImplicits._
 
   val testFileDesc = 
testFile("protobuf/functions_suite.desc").replace("file:/", "/")
+  private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
+
+  /**
+   * Runs the given closure twice. Once with descriptor file and second time 
with Java class name.
+   */
+  private def checkWithFileAndClassName(messageName: String)(
+    fn: (String, Option[String]) => Unit): Unit = {
+      withClue("(With descriptor file)") {
+        fn(messageName, Some(testFileDesc))
+      }
+      withClue("(With Java class name)") {
+        fn(s"$javaClassNamePrefix$messageName", None)
+      }
+  }
+
+  // A wrapper to invoke the right variable of from_protobuf() depending on 
arguments.
+  private def from_protobuf_wrapper(
+    col: Column, messageName: String, descFilePathOpt: Option[String]): Column 
= {
+    descFilePathOpt match {
+      case Some(descFilePath) => functions.from_protobuf(col, messageName, 
descFilePath)
+      case None => functions.from_protobuf(col, messageName)
+    }
+  }
+
+  // A wrapper to invoke the right variable of to_protobuf() depending on 
arguments.
+  private def to_protobuf_wrapper(
+    col: Column, messageName: String, descFilePathOpt: Option[String]): Column 
= {
+    descFilePathOpt match {
+      case Some(descFilePath) => functions.to_protobuf(col, messageName, 
descFilePath)
+      case None => functions.to_protobuf(col, messageName)
+    }
+  }
+
 
   test("roundtrip in to_protobuf and from_protobuf - struct") {
     val df = spark
@@ -56,44 +91,45 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
         
lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"),
         lit(true).as("bool_value"),
         lit("0".getBytes).as("bytes_value")).as("SimpleMessage"))
-    val protoStructDF = df.select(
-      functions.to_protobuf($"SimpleMessage", testFileDesc, 
"SimpleMessage").as("proto"))
-    val actualDf = protoStructDF.select(
-      functions.from_protobuf($"proto", testFileDesc, 
"SimpleMessage").as("proto.*"))
-    checkAnswer(actualDf, df)
+
+    checkWithFileAndClassName("SimpleMessage") {
+      case (name, descFilePathOpt) =>
+        val protoStructDF = df.select(
+          to_protobuf_wrapper($"SimpleMessage", name, 
descFilePathOpt).as("proto"))
+        val actualDf = protoStructDF.select(
+          from_protobuf_wrapper($"proto", name, descFilePathOpt).as("proto.*"))
+        checkAnswer(actualDf, df)
+    }
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Repeated") {
-    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, 
"SimpleMessageRepeated")
 
-    val dynamicMessage = DynamicMessage
-      .newBuilder(descriptor)
-      .setField(descriptor.findFieldByName("key"), "key")
-      .setField(descriptor.findFieldByName("value"), "value")
-      .addRepeatedField(descriptor.findFieldByName("rbool_value"), false)
-      .addRepeatedField(descriptor.findFieldByName("rbool_value"), true)
-      .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 
1092092.654d)
-      .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 
1092093.654d)
-      .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f)
-      .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f)
-      .addRepeatedField(
-        descriptor.findFieldByName("rnested_enum"),
-        
descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
-      .addRepeatedField(
-        descriptor.findFieldByName("rnested_enum"),
-        
descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
+    val protoMessage = SimpleMessageRepeated
+      .newBuilder()
+      .setKey("key")
+      .setValue("value")
+      .addRboolValue(false)
+      .addRboolValue(true)
+      .addRdoubleValue(1092092.654d)
+      .addRdoubleValue(1092093.654d)
+      .addRfloatValue(10903.0f)
+      .addRfloatValue(10902.0f)
+      .addRnestedEnum(NestedEnum.ESTED_NOTHING)
+      .addRnestedEnum(NestedEnum.NESTED_FIRST)
       .build()
 
-    val df = Seq(dynamicMessage.toByteArray).toDF("value")
-    val fromProtoDF = df.select(
-      functions.from_protobuf($"value", testFileDesc, 
"SimpleMessageRepeated").as("value_from"))
-    val toProtoDF = fromProtoDF.select(
-      functions.to_protobuf($"value_from", testFileDesc, 
"SimpleMessageRepeated").as("value_to"))
-    val toFromProtoDF = toProtoDF.select(
-      functions
-        .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated")
-        .as("value_to_from"))
-    checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    val df = Seq(protoMessage.toByteArray).toDF("value")
+
+    checkWithFileAndClassName("SimpleMessageRepeated") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDF = df.select(
+          from_protobuf_wrapper($"value", name, 
descFilePathOpt).as("value_from"))
+        val toProtoDF = fromProtoDF.select(
+          to_protobuf_wrapper($"value_from", name, 
descFilePathOpt).as("value_to"))
+        val toFromProtoDF = toProtoDF.select(
+          from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
+        checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") {
@@ -120,13 +156,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       .build()
 
     val df = Seq(dynamicMessage.toByteArray).toDF("value")
-    val fromProtoDF = df.select(
-      functions.from_protobuf($"value", testFileDesc, 
"RepeatedMessage").as("value_from"))
-    val toProtoDF = fromProtoDF.select(
-      functions.to_protobuf($"value_from", testFileDesc, 
"RepeatedMessage").as("value_to"))
-    val toFromProtoDF = toProtoDF.select(
-      functions.from_protobuf($"value_to", testFileDesc, 
"RepeatedMessage").as("value_to_from"))
-    checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+
+    checkWithFileAndClassName("RepeatedMessage") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDF = df.select(
+          from_protobuf_wrapper($"value", name, 
descFilePathOpt).as("value_from"))
+        val toProtoDF = fromProtoDF.select(
+          to_protobuf_wrapper($"value_from", name, 
descFilePathOpt).as("value_to"))
+        val toFromProtoDF = toProtoDF.select(
+          from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
+        checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") {
@@ -167,13 +207,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       .build()
 
     val df = Seq(dynamicMessage.toByteArray).toDF("value")
-    val fromProtoDF = df.select(
-      functions.from_protobuf($"value", testFileDesc, 
"RepeatedMessage").as("value_from"))
-    val toProtoDF = fromProtoDF.select(
-      functions.to_protobuf($"value_from", testFileDesc, 
"RepeatedMessage").as("value_to"))
-    val toFromProtoDF = toProtoDF.select(
-      functions.from_protobuf($"value_to", testFileDesc, 
"RepeatedMessage").as("value_to_from"))
-    checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+
+    checkWithFileAndClassName("RepeatedMessage") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDF = df.select(
+          from_protobuf_wrapper($"value", name, 
descFilePathOpt).as("value_from"))
+        val toProtoDF = fromProtoDF.select(
+          to_protobuf_wrapper($"value_from", name, 
descFilePathOpt).as("value_to"))
+        val toFromProtoDF = toProtoDF.select(
+          from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
+        checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Map") {
@@ -257,13 +301,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       .build()
 
     val df = Seq(dynamicMessage.toByteArray).toDF("value")
-    val fromProtoDF = df.select(
-      functions.from_protobuf($"value", testFileDesc, 
"SimpleMessageMap").as("value_from"))
-    val toProtoDF = fromProtoDF.select(
-      functions.to_protobuf($"value_from", testFileDesc, 
"SimpleMessageMap").as("value_to"))
-    val toFromProtoDF = toProtoDF.select(
-      functions.from_protobuf($"value_to", testFileDesc, 
"SimpleMessageMap").as("value_to_from"))
-    checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+
+    checkWithFileAndClassName("SimpleMessageMap") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDF = df.select(
+          from_protobuf_wrapper($"value", name, 
descFilePathOpt).as("value_from"))
+        val toProtoDF = fromProtoDF.select(
+          to_protobuf_wrapper($"value_from", name, 
descFilePathOpt).as("value_to"))
+        val toFromProtoDF = toProtoDF.select(
+          from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
+        checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Enum") {
@@ -289,13 +337,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       .build()
 
     val df = Seq(dynamicMessage.toByteArray).toDF("value")
-    val fromProtoDF = df.select(
-      functions.from_protobuf($"value", testFileDesc, 
"SimpleMessageEnum").as("value_from"))
-    val toProtoDF = fromProtoDF.select(
-      functions.to_protobuf($"value_from", testFileDesc, 
"SimpleMessageEnum").as("value_to"))
-    val toFromProtoDF = toProtoDF.select(
-      functions.from_protobuf($"value_to", testFileDesc, 
"SimpleMessageEnum").as("value_to_from"))
-    checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+
+    checkWithFileAndClassName("SimpleMessageEnum") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDF = df.select(
+          from_protobuf_wrapper($"value", name, 
descFilePathOpt).as("value_from"))
+        val toProtoDF = fromProtoDF.select(
+          to_protobuf_wrapper($"value_from", name, 
descFilePathOpt).as("value_to"))
+        val toFromProtoDF = toProtoDF.select(
+          from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
+        checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
   }
 
   test("roundtrip in from_protobuf and to_protobuf - Multiple Message") {
@@ -320,13 +372,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       .build()
 
     val df = Seq(dynamicMessage.toByteArray).toDF("value")
-    val fromProtoDF = df.select(
-      functions.from_protobuf($"value", testFileDesc, 
"MultipleExample").as("value_from"))
-    val toProtoDF = fromProtoDF.select(
-      functions.to_protobuf($"value_from", testFileDesc, 
"MultipleExample").as("value_to"))
-    val toFromProtoDF = toProtoDF.select(
-      functions.from_protobuf($"value_to", testFileDesc, 
"MultipleExample").as("value_to_from"))
-    checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+
+    checkWithFileAndClassName("MultipleExample") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDF = df.select(
+          from_protobuf_wrapper($"value", name, 
descFilePathOpt).as("value_from"))
+        val toProtoDF = fromProtoDF.select(
+          to_protobuf_wrapper($"value_from", name, 
descFilePathOpt).as("value_to"))
+        val toFromProtoDF = toProtoDF.select(
+          from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
+        checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
+    }
   }
 
   test("Handle recursive fields in Protobuf schema, A->B->A") {
@@ -352,15 +408,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
 
     val df = Seq(messageB.toByteArray).toDF("messageB")
 
-    val e = intercept[IncompatibleSchemaException] {
-      df.select(
-        functions.from_protobuf($"messageB", testFileDesc, 
"recursiveB").as("messageFromProto"))
-        .show()
+    checkWithFileAndClassName("recursiveB") {
+      case (name, descFilePathOpt) =>
+        val e = intercept[IncompatibleSchemaException] {
+          df.select(
+            from_protobuf_wrapper($"messageB", name, 
descFilePathOpt).as("messageFromProto"))
+            .show()
+        }
+        assert(e.getMessage.contains(
+          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark:"
+        ))
     }
-    val expectedMessage = s"""
-         |Found recursive reference in Protobuf schema, which can not be 
processed by Spark:
-         |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin
-    assert(e.getMessage == expectedMessage)
   }
 
   test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
@@ -386,16 +444,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
 
     val df = Seq(messageD.toByteArray).toDF("messageD")
 
-    val e = intercept[IncompatibleSchemaException] {
-      df.select(
-        functions.from_protobuf($"messageD", testFileDesc, 
"recursiveD").as("messageFromProto"))
-        .show()
+    checkWithFileAndClassName("recursiveD") {
+      case (name, descFilePathOpt) =>
+        val e = intercept[IncompatibleSchemaException] {
+          df.select(
+            from_protobuf_wrapper($"messageD", name, 
descFilePathOpt).as("messageFromProto"))
+            .show()
+        }
+        assert(e.getMessage.contains(
+          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark:"
+        ))
     }
-    val expectedMessage =
-      s"""
-         |Found recursive reference in Protobuf schema, which can not be 
processed by Spark:
-         |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin
-    assert(e.getMessage == expectedMessage)
   }
 
   test("Handle extra fields : oldProducer -> newConsumer") {
@@ -411,17 +470,17 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
     val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData")
     val fromProtoDf = df.select(
       functions
-        .from_protobuf($"oldProducerData", testFileDesc, "newConsumer")
+        .from_protobuf($"oldProducerData", "newConsumer", testFileDesc)
         .as("fromProto"))
 
     val toProtoDf = fromProtoDf.select(
       functions
-        .to_protobuf($"fromProto", testFileDesc, "newConsumer")
+        .to_protobuf($"fromProto", "newConsumer", testFileDesc)
         .as("toProto"))
 
     val toProtoDfToFromProtoDf = toProtoDf.select(
       functions
-        .from_protobuf($"toProto", testFileDesc, "newConsumer")
+        .from_protobuf($"toProto", "newConsumer", testFileDesc)
         .as("toProtoToFromProto"))
 
     val actualFieldNames =
@@ -452,7 +511,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
     val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData")
     val fromProtoDf = df.select(
       functions
-        .from_protobuf($"newProducerData", testFileDesc, "oldConsumer")
+        .from_protobuf($"newProducerData", "oldConsumer", testFileDesc)
         .as("oldConsumerProto"))
 
     val expectedFieldNames = oldConsumer.getFields.asScala.map(f => f.getName)
@@ -481,8 +540,9 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       )),
       schema
     )
+
     val toProtobuf = inputDf.select(
-      functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg")
+      functions.to_protobuf($"requiredMsg", "requiredMsg", testFileDesc)
         .as("to_proto"))
 
     val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]]
@@ -498,7 +558,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
     assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) 
== 0)
 
     val fromProtoDf = toProtobuf.select(
-      functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 
'from_proto)
+      functions.from_protobuf($"to_proto", "requiredMsg", testFileDesc) as 
'from_proto)
 
     assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0)
       == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
@@ -526,16 +586,20 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       .build()
 
     val df = Seq(basicMessage.toByteArray).toDF("value")
-    val resultFrom = df
-      .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") 
as 'sample)
-      .where("sample.string_value == \"slam\"")
 
-    val resultToFrom = resultFrom
-      .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") 
as 'value)
-      .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") 
as 'sample)
-      .where("sample.string_value == \"slam\"")
+    checkWithFileAndClassName("BasicMessage") {
+      case (name, descFilePathOpt) =>
+        val resultFrom = df
+          .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 
'sample)
+          .where("sample.string_value == \"slam\"")
+
+        val resultToFrom = resultFrom
+          .select(to_protobuf_wrapper($"sample", name, descFilePathOpt) as 
'value)
+          .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 
'sample)
+          .where("sample.string_value == \"slam\"")
 
-    assert(resultFrom.except(resultToFrom).isEmpty)
+        assert(resultFrom.except(resultToFrom).isEmpty)
+    }
   }
 
   test("Handle TimestampType between to_protobuf and from_protobuf") {
@@ -556,22 +620,24 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       schema
     )
 
-    val toProtoDf = inputDf
-      .select(functions.to_protobuf($"timeStampMsg", testFileDesc, 
"timeStampMsg") as 'to_proto)
+    checkWithFileAndClassName("timeStampMsg") {
+      case (name, descFilePathOpt) =>
+        val toProtoDf = inputDf
+          .select(to_protobuf_wrapper($"timeStampMsg", name, descFilePathOpt) 
as 'to_proto)
 
-    val fromProtoDf = toProtoDf
-      .select(functions.from_protobuf($"to_proto", testFileDesc, 
"timeStampMsg") as 'timeStampMsg)
-    fromProtoDf.show(truncate = false)
+        val fromProtoDf = toProtoDf
+          .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 
'timeStampMsg)
 
-    val actualFields = fromProtoDf.schema.fields.toList
-    val expectedFields = inputDf.schema.fields.toList
+        val actualFields = fromProtoDf.schema.fields.toList
+        val expectedFields = inputDf.schema.fields.toList
 
-    assert(actualFields.size === expectedFields.size)
-    assert(actualFields === expectedFields)
-    assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)
-      === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0))
-    assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)
-      === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0))
+        assert(actualFields.size === expectedFields.size)
+        assert(actualFields === expectedFields)
+        assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)
+          === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0))
+        assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)
+          === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0))
+    }
   }
 
   test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") {
@@ -595,21 +661,23 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Seri
       schema
     )
 
-    val toProtoDf = inputDf
-      .select(functions.to_protobuf($"durationMsg", testFileDesc, 
"durationMsg") as 'to_proto)
+    checkWithFileAndClassName("durationMsg") {
+      case (name, descFilePathOpt) =>
+        val toProtoDf = inputDf
+          .select(to_protobuf_wrapper($"durationMsg", name, descFilePathOpt) 
as 'to_proto)
 
-    val fromProtoDf = toProtoDf
-      .select(functions.from_protobuf($"to_proto", testFileDesc, 
"durationMsg") as 'durationMsg)
+        val fromProtoDf = toProtoDf
+          .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 
'durationMsg)
 
-    val actualFields = fromProtoDf.schema.fields.toList
-    val expectedFields = inputDf.schema.fields.toList
-
-    assert(actualFields.size === expectedFields.size)
-    assert(actualFields === expectedFields)
-    assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0)
-      === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0))
-    assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0)
-      === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
+        val actualFields = fromProtoDf.schema.fields.toList
+        val expectedFields = inputDf.schema.fields.toList
 
+        assert(actualFields.size === expectedFields.size)
+        assert(actualFields === expectedFields)
+        assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0)
+          === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0))
+        
assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0)
+          === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
+    }
   }
 }
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 37c59743e77..efc02524e68 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -36,6 +36,7 @@ class ProtobufSerdeSuite extends SharedSparkSession {
   import ProtoSerdeSuite.MatchType._
 
   val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", 
"/")
+  private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
 
   test("Test basic conversion") {
     withFieldMatchType { fieldMatch =>
@@ -96,7 +97,9 @@ class ProtobufSerdeSuite extends SharedSparkSession {
   }
 
   test("Fail to convert with deeply nested field type mismatch") {
-    val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, 
"MissMatchTypeInDeepNested")
+    val protoFile = ProtobufUtils.buildDescriptorFromJavaClass(
+      s"${javaClassNamePrefix}MissMatchTypeInDeepNested"
+    )
     val catalyst = new StructType().add("top", CATALYST_STRUCT)
 
     withFieldMatchType { fieldMatch =>
@@ -105,8 +108,8 @@ class ProtobufSerdeSuite extends SharedSparkSession {
         Deserializer,
         fieldMatch,
         s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 
'top.foo.bar' because schema " +
-          s"is incompatible (protoType = 
org.apache.spark.sql.protobuf.TypeMiss.bar " +
-          s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin,
+          s"is incompatible (protoType = 
org.apache.spark.sql.protobuf.protos.TypeMiss.bar " +
+          s"LABEL_OPTIONAL LONG INT64, sqlType = INT)",
         catalyst)
 
       assertFailedConversionMessage(
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index e5a48080e83..cc103e4ab00 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -716,8 +716,10 @@ object SparkProtobuf {
 
     dependencyOverrides += "com.google.protobuf" % "protobuf-java" % 
protoVersion,
 
-    (Compile / PB.targets) := Seq(
-      PB.gens.java -> (Compile / sourceManaged).value,
+    (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources",
+
+    (Test / PB.targets) := Seq(
+      PB.gens.java -> target.value / "generated-test-sources"
     ),
 
     (assembly / test) := { },
diff --git a/python/pyspark/sql/protobuf/functions.py 
b/python/pyspark/sql/protobuf/functions.py
index 9f8b90095df..2059d868c7c 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -31,8 +31,8 @@ if TYPE_CHECKING:
 
 def from_protobuf(
     data: "ColumnOrName",
-    descFilePath: str,
     messageName: str,
+    descFilePath: str,
     options: Optional[Dict[str, str]] = None,
 ) -> Column:
     """
@@ -48,10 +48,10 @@ def from_protobuf(
     ----------
     data : :class:`~pyspark.sql.Column` or str
         the binary column.
-    descFilePath : str
-        the protobuf descriptor in Message GeneratedMessageV3 format.
     messageName: str
         the protobuf message name to look for in descriptor file.
+    descFilePath : str
+        the protobuf descriptor in Message GeneratedMessageV3 format.
     options : dict, optional
         options to control how the protobuf record is parsed.
 
@@ -80,10 +80,10 @@ def from_protobuf(
     ...         f.flush()
     ...         message_name = 'SimpleMessage'
     ...         proto_df = df.select(
-    ...             to_protobuf(df.value, desc_file_path, 
message_name).alias("value"))
+    ...             to_protobuf(df.value, message_name, 
desc_file_path).alias("value"))
     ...         proto_df.show(truncate=False)
     ...         proto_df = proto_df.select(
-    ...             from_protobuf(proto_df.value, desc_file_path, 
message_name).alias("value"))
+    ...             from_protobuf(proto_df.value, message_name, 
desc_file_path).alias("value"))
     ...         proto_df.show(truncate=False)
     +----------------------------------------+
     |value                                   |
@@ -101,7 +101,7 @@ def from_protobuf(
     assert sc is not None and sc._jvm is not None
     try:
         jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
-            _to_java_column(data), descFilePath, messageName, options or {}
+            _to_java_column(data), messageName, descFilePath, options or {}
         )
     except TypeError as e:
         if str(e) == "'JavaPackage' object is not callable":
@@ -110,7 +110,7 @@ def from_protobuf(
     return Column(jc)
 
 
-def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> 
Column:
+def to_protobuf(data: "ColumnOrName", messageName: str, descFilePath: str) -> 
Column:
     """
     Converts a column into binary of protobuf format.
 
@@ -120,10 +120,10 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, 
messageName: str) -> Co
     ----------
     data : :class:`~pyspark.sql.Column` or str
         the data column.
-    descFilePath : str
-        the protobuf descriptor in Message GeneratedMessageV3 format.
     messageName: str
         the protobuf message name to look for in descriptor file.
+    descFilePath : str
+        the protobuf descriptor in Message GeneratedMessageV3 format.
 
     Notes
     -----
@@ -150,7 +150,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, 
messageName: str) -> Co
     ...         f.flush()
     ...         message_name = 'SimpleMessage'
     ...         proto_df = df.select(
-    ...             to_protobuf(df.value, desc_file_path, 
message_name).alias("suite"))
+    ...             to_protobuf(df.value, message_name, 
desc_file_path).alias("suite"))
     ...         proto_df.show(truncate=False)
     +-------------------------------------------+
     |suite                                      |
@@ -162,7 +162,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, 
messageName: str) -> Co
     assert sc is not None and sc._jvm is not None
     try:
         jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
-            _to_java_column(data), descFilePath, messageName
+            _to_java_column(data), messageName, descFilePath
         )
     except TypeError as e:
         if str(e) == "'JavaPackage' object is not callable":


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to