This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 136c7221fa6a [SPARK-50334][SQL] Extract common logic for reading the 
descriptor of PB file
136c7221fa6a is described below

commit 136c7221fa6a3cb542a7b432dc38032e32859679
Author: panbingkun <[email protected]>
AuthorDate: Thu Nov 21 12:07:31 2024 +0100

    [SPARK-50334][SQL] Extract common logic for reading the descriptor of PB 
file
    
    ### What changes were proposed in this pull request?
    This PR aims to:
    - extract the common logic for reading the descriptor of a PB file into one place;
    - align the Spark error condition that is thrown when the PB file cannot be found or cannot be read, so that `from_protobuf` and `to_protobuf` report the same error whether they are called from `connect-client` or from `spark-sql` (or `spark-shell`).
    
    ### Why are the changes needed?
    The logic for reading the descriptor of a PB file is duplicated in several places in the Spark code base, e.g.:
    
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala#L37-L48
    
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala#L304-L315
    
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala#L231-L241
    
    - Consolidating it in one place reduces the maintenance cost.
    - Aligning the Spark error conditions makes the end-user experience more consistent.
    
    ### Does this PR introduce _any_ user-facing change?
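    No API change. Note, however, that the SQL functions `from_protobuf` and `to_protobuf` now report a missing or unreadable descriptor file with the same Spark compilation errors used by the DataFrame API, instead of a plain `RuntimeException`.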
    
    ### How was this patch tested?
    Pass GA.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48874 from panbingkun/SPARK-50334.
    
    Authored-by: panbingkun <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../spark/sql/protobuf/utils/ProtobufUtils.scala   | 17 ---------
 .../ProtobufCatalystDataConversionSuite.scala      |  7 ++--
 .../sql/protobuf/ProtobufFunctionsSuite.scala      |  9 ++---
 .../spark/sql/protobuf/ProtobufSerdeSuite.scala    |  9 ++---
 .../org/apache/spark/sql/protobuf/functions.scala  | 26 +++-----------
 .../org/apache/spark/sql/util/ProtobufUtils.scala  | 41 ++++++++++++++++++++++
 .../expressions/toFromProtobufSqlFunctions.scala   | 28 ++-------------
 7 files changed, 62 insertions(+), 75 deletions(-)

diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index fee1bcdc9670..3d7bba7a82e8 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -17,19 +17,14 @@
 
 package org.apache.spark.sql.protobuf.utils
 
-import java.io.File
-import java.io.FileNotFoundException
-import java.nio.file.NoSuchFileException
 import java.util.Locale
 
 import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
 
 import com.google.protobuf.{DescriptorProtos, Descriptors, 
InvalidProtocolBufferException, Message}
 import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, 
FileDescriptorSet}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 import com.google.protobuf.TypeRegistry
-import org.apache.commons.io.FileUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -228,18 +223,6 @@ private[sql] object ProtobufUtils extends Logging {
     }
   }
 
-  def readDescriptorFileContent(filePath: String): Array[Byte] = {
-    try {
-      FileUtils.readFileToByteArray(new File(filePath))
-    } catch {
-      case ex: FileNotFoundException =>
-        throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, 
ex)
-      case ex: NoSuchFileException =>
-        throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, 
ex)
-      case NonFatal(ex) => throw 
QueryCompilationErrors.descriptorParseError(ex)
-    }
-  }
-
   private def parseFileDescriptorSet(bytes: Array[Byte]): 
List[Descriptors.FileDescriptor] = {
     var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
     try {
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index ad6a88640140..abae1d622d3c 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, 
SchemaConverters}
 import org.apache.spark.sql.sources.{EqualTo, Not}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
 
@@ -39,7 +40,7 @@ class ProtobufCatalystDataConversionSuite
     with ProtobufTestBase {
 
   private val testFileDescFile = protobufDescriptorFile("catalyst_types.desc")
-  private val testFileDesc = 
ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+  private val testFileDesc = 
CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
   private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.CatalystTypes$"
 
   private def checkResultWithEval(
@@ -47,7 +48,7 @@ class ProtobufCatalystDataConversionSuite
       descFilePath: String,
       messageName: String,
       expected: Any): Unit = {
-    val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
+    val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
     withClue("(Eval check with Java class name)") {
       val className = s"$javaClassNamePrefix$messageName"
       checkEvaluation(
@@ -72,7 +73,7 @@ class ProtobufCatalystDataConversionSuite
       actualSchema: String,
       badSchema: String): Unit = {
 
-    val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
+    val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
     val binary = CatalystDataToProtobuf(data, actualSchema, Some(descBytes))
 
     intercept[Exception] {
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 3eaa91e472c4..44a8339ac1f0 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.protobuf.utils.ProtobufOptions
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
 
 class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with 
ProtobufTestBase
   with Serializable {
@@ -40,11 +41,11 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
   import testImplicits._
 
   val testFileDescFile = protobufDescriptorFile("functions_suite.desc")
-  private val testFileDesc = 
ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+  private val testFileDesc = 
CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
   private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
 
   val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
-  val proto2FileDesc = 
ProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
+  val proto2FileDesc = 
CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
   private val proto2JavaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.Proto2Messages$"
 
   private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
@@ -467,7 +468,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
   test("Handle extra fields : oldProducer -> newConsumer") {
     val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
-    val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
+    val descBytes = 
CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
 
     val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
     val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
@@ -509,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
   test("Handle extra fields : newProducer -> oldConsumer") {
     val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
-    val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
+    val descBytes = 
CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
 
     val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
     val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 2737bb9feb3a..f3bd49e1b24a 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -26,6 +26,7 @@ import 
org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
 
 /**
  * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more 
specific focus on
@@ -37,12 +38,12 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
   import ProtoSerdeSuite.MatchType._
 
   private val testFileDescFile = protobufDescriptorFile("serde_suite.desc")
-  private val testFileDesc = 
ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+  private val testFileDesc = 
CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
 
   private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
 
   private val proto2DescFile = protobufDescriptorFile("proto2_messages.desc")
-  private val proto2Desc = 
ProtobufUtils.readDescriptorFileContent(proto2DescFile)
+  private val proto2Desc = 
CommonProtobufUtils.readDescriptorFileContent(proto2DescFile)
 
   test("Test basic conversion") {
     withFieldMatchType { fieldMatch =>
@@ -215,7 +216,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
 
     val e1 = intercept[AnalysisException] {
       ProtobufUtils.buildDescriptor(
-        ProtobufUtils.readDescriptorFileContent(fileDescFile),
+        CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
         "SerdeBasicMessage"
       )
     }
@@ -225,7 +226,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with 
ProtobufTestBase {
       condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")
 
     val basicMessageDescWithoutImports = descriptorSetWithoutImports(
-      ProtobufUtils.readDescriptorFileContent(
+      CommonProtobufUtils.readDescriptorFileContent(
         protobufDescriptorFile("basicmessage.desc")
       ),
       "BasicMessage"
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
index ea9e3c429d65..fab5cdc8de1b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -16,16 +16,12 @@
  */
 package org.apache.spark.sql.protobuf
 
-import java.io.FileNotFoundException
-import java.nio.file.{Files, NoSuchFileException, Paths}
-
 import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.errors.CompilationErrors
 import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.util.ProtobufUtils
 
 // scalastyle:off: object.name
 object functions {
@@ -51,7 +47,7 @@ object functions {
       messageName: String,
       descFilePath: String,
       options: java.util.Map[String, String]): Column = {
-    val descriptorFileContent = readDescriptorFileContent(descFilePath)
+    val descriptorFileContent = 
ProtobufUtils.readDescriptorFileContent(descFilePath)
     from_protobuf(data, messageName, descriptorFileContent, options)
   }
 
@@ -98,7 +94,7 @@ object functions {
    */
   @Experimental
   def from_protobuf(data: Column, messageName: String, descFilePath: String): 
Column = {
-    val fileContent = readDescriptorFileContent(descFilePath)
+    val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
     from_protobuf(data, messageName, fileContent)
   }
 
@@ -226,7 +222,7 @@ object functions {
       messageName: String,
       descFilePath: String,
       options: java.util.Map[String, String]): Column = {
-    val fileContent = readDescriptorFileContent(descFilePath)
+    val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
     to_protobuf(data, messageName, fileContent, options)
   }
 
@@ -299,18 +295,4 @@ object functions {
       options: java.util.Map[String, String]): Column = {
     Column.fnWithOptions("to_protobuf", options.asScala.iterator, data, 
lit(messageClassName))
   }
-
-  // This method is copied from 
org.apache.spark.sql.protobuf.util.ProtobufUtils
-  private def readDescriptorFileContent(filePath: String): Array[Byte] = {
-    try {
-      Files.readAllBytes(Paths.get(filePath))
-    } catch {
-      case ex: FileNotFoundException =>
-        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
-      case ex: NoSuchFileException =>
-        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
-      case NonFatal(ex) =>
-        throw CompilationErrors.descriptorParseError(ex)
-    }
-  }
 }
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala
new file mode 100644
index 000000000000..11f35ceb060c
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import java.io.{File, FileNotFoundException}
+import java.nio.file.NoSuchFileException
+
+import scala.util.control.NonFatal
+
+import org.apache.commons.io.FileUtils
+
+import org.apache.spark.sql.errors.CompilationErrors
+
+object ProtobufUtils {
+  def readDescriptorFileContent(filePath: String): Array[Byte] = {
+    try {
+      FileUtils.readFileToByteArray(new File(filePath))
+    } catch {
+      case ex: FileNotFoundException =>
+        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
+      case ex: NoSuchFileException =>
+        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
+      case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex)
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
index ad9610ea0c78..96bcf49dbd09 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
@@ -17,37 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.io.File
-import java.io.FileNotFoundException
-import java.nio.file.NoSuchFileException
-
-import scala.util.control.NonFatal
-
-import org.apache.commons.io.FileUtils
-
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType}
+import org.apache.spark.sql.util.ProtobufUtils
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-object ProtobufHelper {
-  def readDescriptorFileContent(filePath: String): Array[Byte] = {
-    try {
-      FileUtils.readFileToByteArray(new File(filePath))
-    } catch {
-      case ex: FileNotFoundException =>
-        throw new RuntimeException(s"Cannot find descriptor file at path: 
$filePath", ex)
-      case ex: NoSuchFileException =>
-        throw new RuntimeException(s"Cannot find descriptor file at path: 
$filePath", ex)
-      case NonFatal(ex) =>
-        throw new RuntimeException(s"Failed to read the descriptor file: 
$filePath", ex)
-    }
-  }
-}
-
 /**
  * Converts a binary column of Protobuf format into its corresponding catalyst 
value.
  * The Protobuf definition is provided through Protobuf <i>descriptor file</i>.
@@ -163,7 +141,7 @@ case class FromProtobuf(
     }
     val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
       case s: UTF8String if s.toString.isEmpty => None
-      case s: UTF8String => 
Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+      case s: UTF8String => 
Some(ProtobufUtils.readDescriptorFileContent(s.toString))
       case bytes: Array[Byte] if bytes.isEmpty => None
       case bytes: Array[Byte] => Some(bytes)
       case null => None
@@ -300,7 +278,7 @@ case class ToProtobuf(
         s.toString
     }
     val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
-      case s: UTF8String => 
Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+      case s: UTF8String => 
Some(ProtobufUtils.readDescriptorFileContent(s.toString))
       case bytes: Array[Byte] => Some(bytes)
       case null => None
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to