This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 461026c62ba1 [SPARK-46754][SQL][AVRO] Fix compression code resolution 
in avro table definition and write options
461026c62ba1 is described below

commit 461026c62ba19e4248d529c4971d3ba74fba2a2d
Author: Kent Yao <y...@apache.org>
AuthorDate: Thu Jan 18 15:53:04 2024 +0800

    [SPARK-46754][SQL][AVRO] Fix compression code resolution in avro table 
definition and write options
    
    ### What changes were proposed in this pull request?
    
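    This PR makes the value of the 'compression' option case-insensitive in the 
Avro table definition and the write options, consistent with other file 
sources. It also fixes the handling of invalid codec names: the previous logic 
for dealing with them was unreachable, and such names now raise a 
SparkIllegalArgumentException with error class CODEC_SHORT_NAME_NOT_FOUND.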
    
    ### Why are the changes needed?
    
    bugfix
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. 'compression'='Xz' and 'compression'='XZ' now work as well as 
'compression'='xz'.
    
    ### How was this patch tested?
    
    New tests in AvroCodecSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #44780 from yaooqinn/SPARK-46754.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../org/apache/spark/sql/avro/AvroUtils.scala      | 37 ++++++++++++----------
 .../org/apache/spark/sql/avro/AvroCodecSuite.scala | 30 ++++++++++++++----
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git 
a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala 
b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 25e6aec4d84a..3910cf540628 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.mapreduce.Job
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkIllegalArgumentException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.avro.AvroCompressionCodec._
@@ -102,22 +102,25 @@ private[sql] object AvroUtils extends Logging {
 
     AvroJob.setOutputKeySchema(job, outputAvroSchema)
 
-    if (parsedOptions.compression == UNCOMPRESSED.lowerCaseName()) {
-      job.getConfiguration.setBoolean("mapred.output.compress", false)
-    } else {
-      job.getConfiguration.setBoolean("mapred.output.compress", true)
-      logInfo(s"Compressing Avro output using the ${parsedOptions.compression} 
codec")
-      val codec = AvroCompressionCodec.fromString(parsedOptions.compression) 
match {
-        case DEFLATE =>
-          val deflateLevel = sqlConf.avroDeflateLevel
-          logInfo(s"Avro compression level $deflateLevel will be used for " +
-            s"${DEFLATE.getCodecName()} codec.")
-          job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, 
deflateLevel)
-          DEFLATE.getCodecName()
-        case codec @ (SNAPPY | BZIP2 | XZ | ZSTANDARD) => codec.getCodecName()
-        case unknown => throw new IllegalArgumentException(s"Invalid 
compression codec: $unknown")
-      }
-      job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec)
+    parsedOptions.compression.toLowerCase(Locale.ROOT) match {
+      case codecName if AvroCompressionCodec.values().exists(c => 
c.lowerCaseName() == codecName) =>
+        AvroCompressionCodec.fromString(codecName) match {
+          case UNCOMPRESSED =>
+            job.getConfiguration.setBoolean("mapred.output.compress", false)
+          case compressed =>
+            job.getConfiguration.setBoolean("mapred.output.compress", true)
+            job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, 
compressed.getCodecName)
+            if (compressed == DEFLATE) {
+              val deflateLevel = sqlConf.avroDeflateLevel
+              logInfo(s"Compressing Avro output using the $codecName codec at 
level $deflateLevel")
+              job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, 
deflateLevel)
+            } else {
+              logInfo(s"Compressing Avro output using the $codecName codec")
+            }
+        }
+      case unknown =>
+        throw new SparkIllegalArgumentException(
+          "CODEC_SHORT_NAME_NOT_FOUND", Map("codecName" -> unknown))
     }
 
     new AvroOutputWriterFactory(dataSchema,
diff --git 
a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala 
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
index ec3753b84a55..933b3f989ef7 100644
--- 
a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
+++ 
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.avro
 
+import java.util.Locale
+
+import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.execution.datasources.FileSourceCodecSuite
 import org.apache.spark.sql.internal.SQLConf
 
@@ -27,19 +30,32 @@ class AvroCodecSuite extends FileSourceCodecSuite {
   override protected def availableCodecs =
     AvroCompressionCodec.values().map(_.lowerCaseName()).iterator.to(Seq)
 
-  availableCodecs.foreach { codec =>
+  (availableCodecs ++ availableCodecs.map(_.capitalize)).foreach { codec =>
     test(s"SPARK-46746: attach codec name to avro files - codec $codec") {
       withTable("avro_t") {
         sql(
           s"""CREATE TABLE avro_t
-             | USING $format OPTIONS('compression'='$codec')
-             | AS SELECT 1 as id
-             | """.stripMargin)
-        spark.table("avro_t")
-          .inputFiles.foreach { file =>
-            assert(file.endsWith(s"$codec.avro".stripPrefix("uncompressed")))
+             |USING $format OPTIONS('compression'='$codec')
+             |AS SELECT 1 as id""".stripMargin)
+        spark
+          .table("avro_t")
+          .inputFiles.foreach { f =>
+            
assert(f.endsWith(s"$codec.avro".toLowerCase(Locale.ROOT).stripPrefix("uncompressed")))
           }
       }
     }
   }
+
+  test("SPARK-46754: invalid compression codec name in avro table definition") 
{
+    checkError(
+      exception = intercept[SparkIllegalArgumentException](
+        sql(
+          s"""CREATE TABLE avro_t
+             |USING $format OPTIONS('compression'='unsupported')
+             |AS SELECT 1 as id""".stripMargin)),
+      errorClass = "CODEC_SHORT_NAME_NOT_FOUND",
+      sqlState = Some("42704"),
+      parameters = Map("codecName" -> "unsupported")
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to