This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a8734e3  [SPARK-36831][SQL] Support reading and writing ANSI intervals from/to CSV datasources
a8734e3 is described below

commit a8734e3f1695a3c436f65bbb1d54d1d02b0df33f
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Wed Sep 29 21:22:34 2021 +0300

    [SPARK-36831][SQL] Support reading and writing ANSI intervals from/to CSV datasources
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support reading and writing ANSI intervals from/to CSV datasources.
    With this change, ANSI interval data is written in its literal form, e.g. `INTERVAL '1-2' YEAR TO MONTH`.
    For the reading part, we need to specify the schema explicitly like:
    ```
    val readDF = spark.read.schema("col INTERVAL YEAR TO MONTH").csv(...)
    ```
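    
    For example, the full round trip could look like this (a minimal sketch; the path and column names below are illustrative, not part of this PR):
    ```
    import java.time.Period
    import spark.implicits._

    // java.time.Period values are encoded as the ANSI year-month interval type.
    val df = Seq((1, Period.ofMonths(14))).toDF("id", "ym")

    // The interval column is written in its literal form, e.g. INTERVAL '1-2' YEAR TO MONTH.
    df.write.csv("/tmp/ym_csv")

    // Reading it back needs the interval type in the schema;
    // otherwise the column is read as a plain string.
    val readBack = spark.read.schema("id INT, ym INTERVAL YEAR TO MONTH").csv("/tmp/ym_csv")
    ```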
    
    ### Why are the changes needed?
    
    For better usability. There should be no reason to prohibit reading/writing ANSI intervals from/to CSV datasources.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New test. It covers both V1 and V2 sources.
    
    Closes #34142 from sarutak/ansi-interval-csv-source.
    
    Authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../sql/execution/datasources/DataSource.scala      |  9 +++++++--
 .../execution/datasources/csv/CSVFileFormat.scala   |  2 --
 .../sql/execution/datasources/v2/csv/CSVTable.scala |  4 +---
 .../datasources/CommonFileDataSourceSuite.scala     |  2 +-
 .../sql/execution/datasources/csv/CSVSuite.scala    | 21 ++++++++++++++++++++-
 5 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 0707af4..be9a912 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -579,11 +579,16 @@ case class DataSource(
       checkEmptyGlobPath, checkFilesExist, enableGlobbing = globPaths)
   }
 
+  // TODO: Remove the Set below once all the built-in datasources support ANSI interval types
+  private val writeAllowedSources: Set[Class[_]] =
+    Set(classOf[ParquetFileFormat], classOf[CSVFileFormat])
+
   private def disallowWritingIntervals(
       dataTypes: Seq[DataType],
       forbidAnsiIntervals: Boolean): Unit = {
-    val isParquet = providingClass == classOf[ParquetFileFormat]
-    dataTypes.foreach(TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals || !isParquet) {
+    val isWriteAllowedSource = writeAllowedSources(providingClass)
+    dataTypes.foreach(
+      TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals || !isWriteAllowedSource) {
       throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 8add63c..d40ad9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -148,8 +148,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
 
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
index 02601b3..839cd01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuild
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.types.{AnsiIntervalType, AtomicType, DataType, StructType, UserDefinedType}
+import org.apache.spark.sql.types.{AtomicType, DataType, StructType, UserDefinedType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class CSVTable(
@@ -55,8 +55,6 @@ case class CSVTable(
     }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
index 39e00e2..61a4ccd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
@@ -36,7 +36,7 @@ trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>
   protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)
 
   test(s"SPARK-36349: disallow saving of ANSI intervals to $dataSourceFormat") 
{
-    if (!Set("parquet").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+    if (!Set("csv", 
"parquet").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
       Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
         withTempPath { dir =>
           val errMsg = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index c46e84a..2a3f2b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -22,7 +22,7 @@ import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException}
 import java.nio.file.{Files, StandardOpenOption}
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.time.{Instant, LocalDate, LocalDateTime}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 import java.util.Locale
 import java.util.zip.GZIPOutputStream
 
@@ -2518,6 +2518,25 @@ abstract class CSVSuite
       }
     }
   }
+
+  test("SPARK-36831: Support reading and writing ANSI intervals") {
+    Seq(
+      YearMonthIntervalType() -> ((i: Int) => Period.of(i, i, 0)),
+      DayTimeIntervalType() -> ((i: Int) => Duration.ofDays(i).plusSeconds(i))
+    ).foreach { case (it, f) =>
+      val data = (1 to 10).map(i => Row(i, f(i)))
+      val schema = StructType(Array(StructField("d", IntegerType, false),
+        StructField("i", it, false)))
+      withTempPath { file =>
+        val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+        df.write.csv(file.getCanonicalPath)
+        val df2 = spark.read.csv(file.getCanonicalPath)
+        checkAnswer(df2, df.select($"d".cast(StringType), $"i".cast(StringType)).collect().toSeq)
+        val df3 = spark.read.schema(schema).csv(file.getCanonicalPath)
+        checkAnswer(df3, df.collect().toSeq)
+      }
+    }
+  }
 }
 
 class CSVv1Suite extends CSVSuite {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
