This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d1f4d3e7711 [SPARK-54557][SQL] Make CSV/JSON/XmlOptions and CSV/JSON/XmlInferSchema comparable
9d1f4d3e7711 is described below

commit 9d1f4d3e771176485df38ef484c924a814a8d3e3
Author: mihailoale-db <[email protected]>
AuthorDate: Wed Dec 3 07:21:04 2025 +0900

    [SPARK-54557][SQL] Make CSV/JSON/XmlOptions and CSV/JSON/XmlInferSchema comparable
    
    ### What changes were proposed in this pull request?
    In this PR I propose to make `XmlOptions` and `XmlInferSchema` comparable.
    
    ### Why are the changes needed?
    In order to be able to compare them while working on the single-pass implementation (dual-runs).
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53268 from mihailoale-db/xmlequalsimplement.
    
    Authored-by: mihailoale-db <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/catalyst/csv/CSVInferSchema.scala    |  8 +++++++
 .../apache/spark/sql/catalyst/csv/CSVOptions.scala | 22 +++++++++++++++++--
 .../spark/sql/catalyst/json/JSONOptions.scala      | 20 +++++++++++++++--
 .../spark/sql/catalyst/json/JsonInferSchema.scala  |  9 +++++++-
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    | 15 ++++++++++++-
 .../apache/spark/sql/catalyst/xml/XmlOptions.scala | 25 +++++++++++++++++++---
 6 files changed, 90 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index 5444ab684586..0cc11ce6bb89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -68,6 +68,14 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
 
   private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType
 
+  override def equals(obj: Any): Boolean = obj match {
+    case other: CSVInferSchema =>
+      options == other.options
+    case _ => false
+  }
+
+  override def hashCode(): Int = options.hashCode()
+
   /**
    * Similar to the JSON schema inference
    *     1. Infer type of each row
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index f3862a600022..5f1560060850 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -34,8 +34,8 @@ import org.apache.spark.sql.types.StructType
 class CSVOptions(
     @transient val parameters: CaseInsensitiveMap[String],
     val columnPruning: Boolean,
-    defaultTimeZoneId: String,
-    defaultColumnNameOfCorruptRecord: String)
+    private val defaultTimeZoneId: String,
+    private val defaultColumnNameOfCorruptRecord: String)
   extends FileSourceOptions(parameters) with Logging {
 
   import CSVOptions._
@@ -63,6 +63,24 @@ class CSVOptions(
         defaultColumnNameOfCorruptRecord)
   }
 
+  override def equals(obj: Any): Boolean = obj match {
+    case other: CSVOptions =>
+      (parameters == null && other.parameters == null ||
+      parameters != null && parameters == other.parameters) &&
+      columnPruning == other.columnPruning &&
+      defaultTimeZoneId == other.defaultTimeZoneId &&
+      defaultColumnNameOfCorruptRecord == other.defaultColumnNameOfCorruptRecord
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    var result = Option(parameters).map(_.hashCode()).getOrElse(0)
+    result = 31 * result + (if (columnPruning) 1 else 0)
+    result = 31 * result + defaultTimeZoneId.hashCode()
+    result = 31 * result + defaultColumnNameOfCorruptRecord.hashCode()
+    result
+  }
+
   private def getChar(paramName: String, default: Char): Char = {
     val paramValue = parameters.get(paramName)
     paramValue match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index a8f98901e6ab..d90491c2cf3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -36,8 +36,8 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
  */
 class JSONOptions(
     @transient val parameters: CaseInsensitiveMap[String],
-    defaultTimeZoneId: String,
-    defaultColumnNameOfCorruptRecord: String)
+    private val defaultTimeZoneId: String,
+    private val defaultColumnNameOfCorruptRecord: String)
   extends FileSourceOptions(parameters) with Logging  {
 
   import JSONOptions._
@@ -156,6 +156,22 @@ class JSONOptions(
   protected def checkedEncoding(enc: String): String =
     CharsetProvider.forName(enc, caller = "JSONOptions").name()
 
+  override def equals(obj: Any): Boolean = obj match {
+    case other: JSONOptions =>
+      (parameters == null && other.parameters == null ||
+      parameters != null && parameters == other.parameters) &&
+      defaultTimeZoneId == other.defaultTimeZoneId &&
+      defaultColumnNameOfCorruptRecord == other.defaultColumnNameOfCorruptRecord
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    var result = Option(parameters).map(_.hashCode()).getOrElse(0)
+    result = 31 * result + defaultTimeZoneId.hashCode()
+    result = 31 * result + defaultColumnNameOfCorruptRecord.hashCode()
+    result
+  }
+
   /**
   * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE.
   * If the encoding is not specified (None) in read, it will be detected automatically
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index b509c55ed6a3..5587bb3d30f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -39,7 +39,7 @@ import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
 
-class JsonInferSchema(options: JSONOptions) extends Serializable with Logging {
+class JsonInferSchema(private val options: JSONOptions) extends Serializable with Logging {
 
   private val decimalParser = ExprUtils.getDecimalParser(options.locale)
 
@@ -61,6 +61,13 @@ class JsonInferSchema(options: JSONOptions) extends Serializable with Logging {
   private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType
   private val legacyMode = SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY
 
+  override def equals(obj: Any): Boolean = obj match {
+    case other: JsonInferSchema => options == other.options
+    case _ => false
+  }
+
+  override def hashCode(): Int = options.hashCode()
+
   private def handleJsonErrorsByParseMode(parseMode: ParseMode,
       columnNameOfCorruptRecord: String, e: Throwable): Option[StructType] = {
     parseMode match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index d328c6c226a9..a0139e98a266 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SparkErrorUtils
 
-class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
+class XmlInferSchema(private val options: XmlOptions, private val caseSensitive: Boolean)
     extends Serializable
     with Logging {
 
@@ -73,6 +73,19 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     legacyFormat = FAST_DATE_FORMAT,
     isParsing = true)
 
+  override def equals(obj: Any): Boolean = obj match {
+    case other: XmlInferSchema =>
+      options == other.options &&
+      caseSensitive == other.caseSensitive
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    var result = options.hashCode()
+    result = 31 * result + (if (caseSensitive) 1 else 0)
+    result
+  }
+
   private def handleXmlErrorsByParseMode(
       parser: XMLEventReader,
       parseMode: ParseMode,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index a08ac860970e..92740df452e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -32,9 +32,9 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
  */
 class XmlOptions(
     val parameters: CaseInsensitiveMap[String],
-    defaultTimeZoneId: String,
-    defaultColumnNameOfCorruptRecord: String,
-    rowTagRequired: Boolean)
+    private val defaultTimeZoneId: String,
+    private val defaultColumnNameOfCorruptRecord: String,
+    private val rowTagRequired: Boolean)
   extends FileSourceOptions(parameters) with Logging {
 
   import XmlOptions._
@@ -51,6 +51,25 @@ class XmlOptions(
       rowTagRequired)
   }
 
+
+  override def equals(obj: Any): Boolean = obj match {
+    case other: XmlOptions =>
+      (parameters == null && other.parameters == null ||
+      parameters != null && parameters == other.parameters) &&
+      defaultTimeZoneId == other.defaultTimeZoneId &&
+      defaultColumnNameOfCorruptRecord == other.defaultColumnNameOfCorruptRecord &&
+      rowTagRequired == other.rowTagRequired
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    var result = Option(parameters).map(_.hashCode()).getOrElse(0)
+    result = 31 * result + defaultTimeZoneId.hashCode()
+    result = 31 * result + defaultColumnNameOfCorruptRecord.hashCode()
+    result = 31 * result + (if (rowTagRequired) 1 else 0)
+    result
+  }
+
   private def getBool(paramName: String, default: Boolean = false): Boolean = {
     val param = parameters.getOrElse(paramName, default.toString)
     if (param == null) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to