This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 136c7221fa6a [SPARK-50334][SQL] Extract common logic for reading the descriptor of PB file
136c7221fa6a is described below
commit 136c7221fa6a3cb542a7b432dc38032e32859679
Author: panbingkun <[email protected]>
AuthorDate: Thu Nov 21 12:07:31 2024 +0100
[SPARK-50334][SQL] Extract common logic for reading the descriptor of PB file
### What changes were proposed in this pull request?
This PR aims to:
- extract the `common` logic for `reading the descriptor of a PB file` into one place.
- at the same time, align the Spark error-condition thrown when `the PB file is not found` or `the read fails`, so it is the same whether the `from_protobuf` or `to_protobuf` function is used from `connect-client` or from `spark-sql` (or `spark-shell`). A minimal usage sketch follows this list.
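For context, a minimal sketch of the file-path overloads involved (the descriptor path, message name, and input DataFrame `events` are hypothetical, and `spark.implicits._` is assumed to be imported); after this change, both the Connect client and classic Spark SQL implementations read the descriptor bytes through the same shared helper:
```scala
import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

// Hypothetical descriptor file and message name, for illustration only.
val descFilePath = "/tmp/functions_suite.desc"
val messageName = "SimpleMessage"

// `events` is assumed to be a DataFrame with a binary column "payload"
// holding serialized protobuf messages; the descriptor file is read once
// on the driver and its bytes are passed along with the plan.
val decoded = events.select(
  from_protobuf($"payload", messageName, descFilePath).as("event"))

// Round-trip back to binary using the same descriptor file.
val reencoded = decoded.select(
  to_protobuf($"event", messageName, descFilePath).as("payload"))
```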
### Why are the changes needed?
I found that the logic for `reading the descriptor of a PB file` is scattered across various places in the `spark code repository`, e.g.:
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala#L37-L48
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala#L304-L315
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala#L231-L241
- I think we should gather it in one place to reduce the cost of maintenance.
- Aligning the `spark error-condition` improves consistency of the end-user experience (see the sketch below).
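As an illustration of the alignment, a hedged sketch (the missing path is made up; the condition name `PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND` and the `getCondition` accessor are assumptions, not confirmed by this patch): a missing descriptor file now raises the same `AnalysisException` from every entry point, because all of them delegate to the shared helper.
```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.util.ProtobufUtils

// A path that intentionally does not exist, to hit the "not found" branch.
val missingDescFile = "/tmp/does_not_exist.desc"

try {
  ProtobufUtils.readDescriptorFileContent(missingDescFile)
} catch {
  case e: AnalysisException =>
    // connect-client and spark-sql both go through this helper now, so the
    // error condition (assumed: PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND) matches.
    println(s"condition = ${e.getCondition}")
}
```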
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Pass GA.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48874 from panbingkun/SPARK-50334.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../spark/sql/protobuf/utils/ProtobufUtils.scala | 17 ---------
.../ProtobufCatalystDataConversionSuite.scala | 7 ++--
.../sql/protobuf/ProtobufFunctionsSuite.scala | 9 ++---
.../spark/sql/protobuf/ProtobufSerdeSuite.scala | 9 ++---
.../org/apache/spark/sql/protobuf/functions.scala | 26 +++-----------
.../org/apache/spark/sql/util/ProtobufUtils.scala | 41 ++++++++++++++++++++++
.../expressions/toFromProtobufSqlFunctions.scala | 28 ++-------------
7 files changed, 62 insertions(+), 75 deletions(-)
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index fee1bcdc9670..3d7bba7a82e8 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -17,19 +17,14 @@
package org.apache.spark.sql.protobuf.utils
-import java.io.File
-import java.io.FileNotFoundException
-import java.nio.file.NoSuchFileException
import java.util.Locale
import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message}
import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.TypeRegistry
-import org.apache.commons.io.FileUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -228,18 +223,6 @@ private[sql] object ProtobufUtils extends Logging {
}
}
- def readDescriptorFileContent(filePath: String): Array[Byte] = {
- try {
- FileUtils.readFileToByteArray(new File(filePath))
- } catch {
- case ex: FileNotFoundException =>
- throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
- case ex: NoSuchFileException =>
- throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
- case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex)
- }
- }
-
private def parseFileDescriptorSet(bytes: Array[Byte]): List[Descriptors.FileDescriptor] = {
var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
try {
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index ad6a88640140..abae1d622d3c 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
import org.apache.spark.sql.sources.{EqualTo, Not}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
@@ -39,7 +40,7 @@ class ProtobufCatalystDataConversionSuite
with ProtobufTestBase {
private val testFileDescFile = protobufDescriptorFile("catalyst_types.desc")
- private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+ private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$"
private def checkResultWithEval(
@@ -47,7 +48,7 @@ class ProtobufCatalystDataConversionSuite
descFilePath: String,
messageName: String,
expected: Any): Unit = {
- val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
+ val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
withClue("(Eval check with Java class name)") {
val className = s"$javaClassNamePrefix$messageName"
checkEvaluation(
@@ -72,7 +73,7 @@ class ProtobufCatalystDataConversionSuite
actualSchema: String,
badSchema: String): Unit = {
- val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
+ val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
val binary = CatalystDataToProtobuf(data, actualSchema, Some(descBytes))
intercept[Exception] {
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 3eaa91e472c4..44a8339ac1f0 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.protobuf.utils.ProtobufOptions
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase
with Serializable {
@@ -40,11 +41,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
import testImplicits._
val testFileDescFile = protobufDescriptorFile("functions_suite.desc")
- private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+ private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
- val proto2FileDesc = ProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
+ val proto2FileDesc = CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
private val proto2JavaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.Proto2Messages$"
private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
@@ -467,7 +468,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
test("Handle extra fields : oldProducer -> newConsumer") {
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
- val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
+ val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
@@ -509,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
test("Handle extra fields : newProducer -> oldConsumer") {
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
- val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
+ val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 2737bb9feb3a..f3bd49e1b24a 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
/**
* Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on
@@ -37,12 +38,12 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
import ProtoSerdeSuite.MatchType._
private val testFileDescFile = protobufDescriptorFile("serde_suite.desc")
- private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+ private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
private val proto2DescFile = protobufDescriptorFile("proto2_messages.desc")
- private val proto2Desc = ProtobufUtils.readDescriptorFileContent(proto2DescFile)
+ private val proto2Desc = CommonProtobufUtils.readDescriptorFileContent(proto2DescFile)
test("Test basic conversion") {
withFieldMatchType { fieldMatch =>
@@ -215,7 +216,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
val e1 = intercept[AnalysisException] {
ProtobufUtils.buildDescriptor(
- ProtobufUtils.readDescriptorFileContent(fileDescFile),
+ CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
"SerdeBasicMessage"
)
}
@@ -225,7 +226,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")
val basicMessageDescWithoutImports = descriptorSetWithoutImports(
- ProtobufUtils.readDescriptorFileContent(
+ CommonProtobufUtils.readDescriptorFileContent(
protobufDescriptorFile("basicmessage.desc")
),
"BasicMessage"
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
index ea9e3c429d65..fab5cdc8de1b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -16,16 +16,12 @@
*/
package org.apache.spark.sql.protobuf
-import java.io.FileNotFoundException
-import java.nio.file.{Files, NoSuchFileException, Paths}
-
import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
-import org.apache.spark.sql.errors.CompilationErrors
import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.util.ProtobufUtils
// scalastyle:off: object.name
object functions {
@@ -51,7 +47,7 @@ object functions {
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
- val descriptorFileContent = readDescriptorFileContent(descFilePath)
+ val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
from_protobuf(data, messageName, descriptorFileContent, options)
}
@@ -98,7 +94,7 @@ object functions {
*/
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
- val fileContent = readDescriptorFileContent(descFilePath)
+ val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
from_protobuf(data, messageName, fileContent)
}
@@ -226,7 +222,7 @@ object functions {
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
- val fileContent = readDescriptorFileContent(descFilePath)
+ val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
to_protobuf(data, messageName, fileContent, options)
}
@@ -299,18 +295,4 @@ object functions {
options: java.util.Map[String, String]): Column = {
Column.fnWithOptions("to_protobuf", options.asScala.iterator, data,
lit(messageClassName))
}
-
- // This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils
- private def readDescriptorFileContent(filePath: String): Array[Byte] = {
- try {
- Files.readAllBytes(Paths.get(filePath))
- } catch {
- case ex: FileNotFoundException =>
- throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
- case ex: NoSuchFileException =>
- throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
- case NonFatal(ex) =>
- throw CompilationErrors.descriptorParseError(ex)
- }
- }
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala
new file mode 100644
index 000000000000..11f35ceb060c
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import java.io.{File, FileNotFoundException}
+import java.nio.file.NoSuchFileException
+
+import scala.util.control.NonFatal
+
+import org.apache.commons.io.FileUtils
+
+import org.apache.spark.sql.errors.CompilationErrors
+
+object ProtobufUtils {
+ def readDescriptorFileContent(filePath: String): Array[Byte] = {
+ try {
+ FileUtils.readFileToByteArray(new File(filePath))
+ } catch {
+ case ex: FileNotFoundException =>
+ throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
+ case ex: NoSuchFileException =>
+ throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
+ case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex)
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
index ad9610ea0c78..96bcf49dbd09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
@@ -17,37 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
-import java.io.File
-import java.io.FileNotFoundException
-import java.nio.file.NoSuchFileException
-
-import scala.util.control.NonFatal
-
-import org.apache.commons.io.FileUtils
-
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType}
+import org.apache.spark.sql.util.ProtobufUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
-object ProtobufHelper {
- def readDescriptorFileContent(filePath: String): Array[Byte] = {
- try {
- FileUtils.readFileToByteArray(new File(filePath))
- } catch {
- case ex: FileNotFoundException =>
- throw new RuntimeException(s"Cannot find descriptor file at path:
$filePath", ex)
- case ex: NoSuchFileException =>
- throw new RuntimeException(s"Cannot find descriptor file at path:
$filePath", ex)
- case NonFatal(ex) =>
- throw new RuntimeException(s"Failed to read the descriptor file:
$filePath", ex)
- }
- }
-}
-
/**
* Converts a binary column of Protobuf format into its corresponding catalyst value.
* The Protobuf definition is provided through Protobuf <i>descriptor file</i>.
@@ -163,7 +141,7 @@ case class FromProtobuf(
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
case s: UTF8String if s.toString.isEmpty => None
- case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+ case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
case bytes: Array[Byte] if bytes.isEmpty => None
case bytes: Array[Byte] => Some(bytes)
case null => None
@@ -300,7 +278,7 @@ case class ToProtobuf(
s.toString
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
- case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+ case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
case bytes: Array[Byte] => Some(bytes)
case null => None
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]