MaxGekk closed pull request #23257: [SPARK-26310][SQL] Verify applicability of JSON options
URL: https://github.com/apache/spark/pull/23257
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index e0cab537ce1c6..44a4d93c0ff8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -569,7 +569,7 @@ case class JsonToStructs(
 
   val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
   @transient lazy val parser = {
-    val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord)
+    val parsedOptions = new JSONOptionsInRead(options, timeZoneId.get, nameOfCorruptRecord)
     val mode = parsedOptions.parseMode
     if (mode != PermissiveMode && mode != FailFastMode) {
       throw new IllegalArgumentException(s"from_json() doesn't support the ${mode.name} mode. " +
@@ -660,7 +660,7 @@ case class StructsToJson(
 
   @transient
   lazy val gen = new JacksonGenerator(
-    inputSchema, writer, new JSONOptions(options, timeZoneId.get))
+    inputSchema, writer, new JSONOptionsInWrite(options, timeZoneId.get))
 
   @transient
   lazy val inputSchema = child.dataType
@@ -764,7 +764,7 @@ case class SchemaOfJson(
   override def nullable: Boolean = false
 
   @transient
-  private lazy val jsonOptions = new JSONOptions(options, "UTC")
+  private lazy val jsonOptions = new JSONOptionsInRead(options, "UTC")
 
   @transient
   private lazy val jsonFactory = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 1ec9d5093a789..3d7facd6ce9c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -24,13 +24,14 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Options for parsing JSON data into Spark SQL rows.
  *
  * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
  */
-private[sql] class JSONOptions(
+private[sql] abstract class JSONOptions(
     @transient val parameters: CaseInsensitiveMap[String],
     defaultTimeZoneId: String,
     defaultColumnNameOfCorruptRecord: String)
@@ -134,6 +135,17 @@ private[sql] class JSONOptions(
       allowBackslashEscapingAnyCharacter)
     factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars)
   }
+
+  def notApplicableOptions: Set[String]
+  def checkOptions(where: String): Unit = {
+    val wrongOptions = notApplicableOptions.filter(parameters.contains(_))
+    if (!wrongOptions.isEmpty && SQLConf.get.verifyDataSourceOptions) {
+      // scalastyle:off throwerror
+      throw new IllegalArgumentException(
+        s"""The JSON options are not applicable $where : 
${wrongOptions.mkString(", ")}.""")
+      // scalastyle:on throwerror
+    }
+  }
 }
 
 private[sql] class JSONOptionsInRead(
@@ -164,6 +176,11 @@ private[sql] class JSONOptionsInRead(
 
     enc
   }
+
+  override def notApplicableOptions: Set[String] = Set(
+    "compression",
+    "pretty")
+  checkOptions("in read")
 }
 
 private[sql] object JSONOptionsInRead {
@@ -178,3 +195,34 @@ private[sql] object JSONOptionsInRead {
     Charset.forName("UTF-32")
   )
 }
+
+private[sql] class JSONOptionsInWrite(
+    @transient override val parameters: CaseInsensitiveMap[String],
+    defaultTimeZoneId: String)
+  extends JSONOptions(parameters, defaultTimeZoneId, "") {
+
+  def this(
+      parameters: Map[String, String],
+      defaultTimeZoneId: String) = {
+    this(
+      CaseInsensitiveMap(parameters),
+      defaultTimeZoneId)
+  }
+
+  override def notApplicableOptions: Set[String] = Set(
+    "samplingRatio",
+    "primitivesAsString",
+    "prefersDecimal",
+    "allowComments",
+    "allowUnquotedFieldNames",
+    "allowSingleQuotes",
+    "allowNumericLeadingZeros",
+    "allowNonNumericNumbers",
+    "allowBackslashEscapingAnyCharacter",
+    "allowUnquotedControlChars",
+    "mode",
+    "columnNameOfCorruptRecord",
+    "dropFieldIfAllNull",
+    "multiLine")
+  checkOptions("in write")
+}
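
To make the verification pattern above easier to follow outside the diff, here is a simplified, self-contained sketch: the base options class declares an abstract list of inapplicable options, and each concrete subclass triggers the check from its constructor body. Class and option names below are illustrative stand-ins, not the actual Spark classes:

    // Simplified sketch of the pattern added in JSONOptions.scala above.
    // Names are illustrative; the real classes also parse the options.
    abstract class OptionsBase(parameters: Map[String, String]) {
      // Each subclass lists the options that do not apply to it.
      def notApplicableOptions: Set[String]

      protected def checkOptions(where: String): Unit = {
        val wrong = notApplicableOptions.filter(parameters.contains)
        if (wrong.nonEmpty) {
          throw new IllegalArgumentException(
            s"The JSON options are not applicable $where : ${wrong.mkString(", ")}.")
        }
      }
    }

    class OptionsInWrite(parameters: Map[String, String])
        extends OptionsBase(parameters) {
      override def notApplicableOptions: Set[String] = Set("multiLine", "mode")
      // Runs in the constructor, so building the options object with an
      // inapplicable option fails immediately.
      checkOptions("in write")
    }
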
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c1b885a72ad3e..d569e38f65fcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1625,6 +1625,14 @@ object SQLConf {
         "a SparkConf entry.")
       .booleanConf
       .createWithDefault(true)
+
+  val VERIFY_DATASOURCE_OPTIONS = buildConf("spark.sql.verifyDataSourceOptions")
+    .doc("When set to true, options passed to a datasource are checked for whether they " +
+      "can be applied in read or in write. For example, if an option that is applicable " +
+      "only in read is passed in write, an exception is raised. " +
+      "To disable the verification, set it to false.")
+    .booleanConf
+    .createWithDefault(true)
 }
 
 /**
@@ -1810,6 +1818,8 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
+  def verifyDataSourceOptions: Boolean = getConf(VERIFY_DATASOURCE_OPTIONS)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala
index 9b27490ed0e35..799a89ff0334c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
 class JacksonGeneratorSuite extends SparkFunSuite {
 
   val gmtId = DateTimeUtils.TimeZoneGMT.getID
-  val option = new JSONOptions(Map.empty, gmtId)
+  val option = new JSONOptionsInRead(Map.empty, gmtId)
 
   test("initial with StructType and write out a row") {
     val dataType = StructType(StructField("a", IntegerType) :: Nil)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ce8e4c8f5b82b..bbc09ea10b068 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
-import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptionsInRead}
 import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -440,7 +440,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 2.2.0
    */
   def json(jsonDataset: Dataset[String]): DataFrame = {
-    val parsedOptions = new JSONOptions(
+    val parsedOptions = new JSONOptionsInRead(
       extraOptions.toMap,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 44cada086489a..047fd34b3b88b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
-import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptionsInWrite}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
@@ -3120,7 +3120,7 @@ class Dataset[T] private[sql](
       val writer = new CharArrayWriter()
       // create the Generator without separator inserted between 2 records
       val gen = new JacksonGenerator(rowSchema, writer,
-        new JSONOptions(Map.empty[String, String], sessionLocalTimeZone))
+        new JSONOptionsInWrite(Map.empty[String, String], sessionLocalTimeZone))
 
       new Iterator[String] {
         override def hasNext: Boolean = iter.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 40f55e7068010..e3b13038012a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -67,10 +67,9 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
     val conf = job.getConfiguration
-    val parsedOptions = new JSONOptions(
+    val parsedOptions = new JSONOptionsInWrite(
       options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+      sparkSession.sessionState.conf.sessionLocalTimeZone)
     parsedOptions.compressionCodec.foreach { codec =>
       CompressionCodecs.setCodecConfiguration(conf, codec)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
index 316c5183fddf1..a63b41914f552 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json
 
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.json.JSONOptions
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 /**
@@ -135,4 +136,22 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
     assert(df.first().getString(0) == "Cazen Lee")
     assert(df.first().getString(1) == "$10")
   }
+
+  test("verify options") {
+    withTempPath { dir =>
+      def invalidOptionUsage: Unit = {
+        val ds = Seq("""{"a": "b"}""").toDS()
+        ds.write.option("dropFieldIfAllNull", true).json(dir.getCanonicalPath)
+      }
+      val exception = intercept[IllegalArgumentException] {
+        invalidOptionUsage
+      }
+      assert(exception.getMessage.contains(
+        "The JSON options are not applicable in write : dropFieldIfAllNull"))
+
+      withSQLConf(SQLConf.VERIFY_DATASOURCE_OPTIONS.key -> "false") {
+        invalidOptionUsage
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 78debb5731116..27f549bf36e20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -68,7 +68,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         generator.flush()
       }
 
-      val dummyOption = new JSONOptions(options, SQLConf.get.sessionLocalTimeZone)
+      val dummyOption = new JSONOptionsInRead(options, SQLConf.get.sessionLocalTimeZone)
       val dummySchema = StructType(Seq.empty)
       val parser = new JacksonParser(dummySchema, dummyOption, allowArrayAsStructs = true)
 
@@ -1384,7 +1384,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
 
   test("SPARK-6245 JsonInferSchema.infer on empty RDD") {
     // This is really a test that it doesn't throw an exception
-    val options = new JSONOptions(Map.empty[String, String], "GMT")
+    val options = new JSONOptionsInRead(Map.empty[String, String], "GMT")
     val emptySchema = new JsonInferSchema(options).infer(
       empty.rdd,
       CreateJacksonParser.string)
@@ -1411,7 +1411,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   }
 
   test("SPARK-8093 Erase empty structs") {
-    val options = new JSONOptions(Map.empty[String, String], "GMT")
+    val options = new JSONOptionsInRead(Map.empty[String, String], "GMT")
     val emptySchema = new JsonInferSchema(options).infer(
       emptyRecords.rdd,
       CreateJacksonParser.string)
@@ -2235,7 +2235,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         val ds = spark.createDataset(Seq(("a", 1))).repartition(1)
         ds.write
           .option("encoding", encoding)
-          .option("multiline", false)
           .json(path.getCanonicalPath)
         val jsonFiles = path.listFiles().filter(_.getName.endsWith("json"))
         jsonFiles.foreach { jsonFile =>
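
For completeness, a minimal end-to-end sketch of the new behavior from a user's perspective, mirroring the test added in JsonParsingOptionsSuite; the session setup and output path are illustrative:

    import org.apache.spark.sql.SparkSession

    object VerifyJsonOptionsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]") // illustrative local session
          .appName("verify-json-options")
          .getOrCreate()
        import spark.implicits._

        val ds = Seq("""{"a": "b"}""").toDS()

        // "dropFieldIfAllNull" only applies in read, so with the verification
        // enabled (the default) this write fails with:
        // "The JSON options are not applicable in write : dropFieldIfAllNull"
        try {
          ds.write.option("dropFieldIfAllNull", true).json("/tmp/json-out")
        } catch {
          case e: IllegalArgumentException => println(e.getMessage)
        }

        // Disabling the verification restores the previous lenient behavior:
        // the inapplicable option is silently ignored.
        spark.conf.set("spark.sql.verifyDataSourceOptions", "false")
        ds.write.mode("overwrite")
          .option("dropFieldIfAllNull", true)
          .json("/tmp/json-out")

        spark.stop()
      }
    }

Note that the check runs when the JSONOptionsInWrite object is constructed, i.e. when the write is executed, so the exception surfaces from the json() call rather than from option().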


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
