This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a8734e3  [SPARK-36831][SQL] Support reading and writing ANSI intervals from/to CSV datasources
a8734e3 is described below

commit a8734e3f1695a3c436f65bbb1d54d1d02b0df33f
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Wed Sep 29 21:22:34 2021 +0300

    [SPARK-36831][SQL] Support reading and writing ANSI intervals from/to CSV datasources

    ### What changes were proposed in this pull request?

    This PR aims to support reading and writing ANSI intervals from/to CSV datasources.
    With this change, interval data is written in literal form like `INTERVAL '1-2' YEAR TO MONTH`.
    For the reading part, we need to specify the schema explicitly, like:
    ```
    val readDF = spark.read.schema("col INTERVAL YEAR TO MONTH").csv(...)
    ```

    ### Why are the changes needed?

    For better usability. There should be no reason to prohibit reading/writing ANSI intervals from/to CSV datasources.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New test. It covers both V1 and V2 sources.

    Closes #34142 from sarutak/ansi-interval-csv-source.

    Authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../sql/execution/datasources/DataSource.scala       |  9 +++++++--
 .../execution/datasources/csv/CSVFileFormat.scala    |  2 --
 .../sql/execution/datasources/v2/csv/CSVTable.scala  |  4 +---
 .../datasources/CommonFileDataSourceSuite.scala      |  2 +-
 .../sql/execution/datasources/csv/CSVSuite.scala     | 21 ++++++++++++++++++++-
 5 files changed, 29 insertions(+), 9 deletions(-)
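For context, a minimal round-trip sketch of the behavior this commit enables (a sketch, not part of the patch: it assumes a spark-shell style SparkSession named `spark`, and the output path and column name are illustrative):

```scala
import java.time.Period
import spark.implicits._

// A java.time.Period column is encoded as YearMonthIntervalType (INTERVAL YEAR TO MONTH).
val df = Seq(Period.ofMonths(14)).toDF("col")

// With this change, the value is written as the ANSI literal INTERVAL '1-2' YEAR TO MONTH.
df.write.csv("/tmp/ansi-interval-csv")

// CSV carries no type information, so the interval type must be given explicitly;
// without a schema, the column comes back as a plain string.
val readDF = spark.read.schema("col INTERVAL YEAR TO MONTH").csv("/tmp/ansi-interval-csv")
readDF.show(false)
```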
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 0707af4..be9a912 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -579,11 +579,16 @@ case class DataSource(
       checkEmptyGlobPath, checkFilesExist, enableGlobbing = globPaths)
   }
 
+  // TODO: Remove the Set below once all the built-in datasources support ANSI interval types
+  private val writeAllowedSources: Set[Class[_]] =
+    Set(classOf[ParquetFileFormat], classOf[CSVFileFormat])
+
   private def disallowWritingIntervals(
       dataTypes: Seq[DataType],
       forbidAnsiIntervals: Boolean): Unit = {
-    val isParquet = providingClass == classOf[ParquetFileFormat]
-    dataTypes.foreach(TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals || !isParquet) {
+    val isWriteAllowedSource = writeAllowedSources(providingClass)
+    dataTypes.foreach(
+      TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals || !isWriteAllowedSource) {
       throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 8add63c..d40ad9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -148,8 +148,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
 
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
index 02601b3..839cd01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuild
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.types.{AnsiIntervalType, AtomicType, DataType, StructType, UserDefinedType}
+import org.apache.spark.sql.types.{AtomicType, DataType, StructType, UserDefinedType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class CSVTable(
@@ -55,8 +55,6 @@ case class CSVTable(
   }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
index 39e00e2..61a4ccd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
@@ -36,7 +36,7 @@ trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>
   protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)
 
   test(s"SPARK-36349: disallow saving of ANSI intervals to $dataSourceFormat") {
-    if (!Set("parquet").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+    if (!Set("csv", "parquet").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
       Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
         withTempPath { dir =>
           val errMsg = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index c46e84a..2a3f2b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -22,7 +22,7 @@ import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException}
 import java.nio.file.{Files, StandardOpenOption}
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.time.{Instant, LocalDate, LocalDateTime}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 import java.util.Locale
 import java.util.zip.GZIPOutputStream
 
@@ -2518,6 +2518,25 @@ abstract class CSVSuite
       }
     }
   }
+
+  test("SPARK-36831: Support reading and writing ANSI intervals") {
+    Seq(
+      YearMonthIntervalType() -> ((i: Int) => Period.of(i, i, 0)),
+      DayTimeIntervalType() -> ((i: Int) => Duration.ofDays(i).plusSeconds(i))
+    ).foreach { case (it, f) =>
+      val data = (1 to 10).map(i => Row(i, f(i)))
+      val schema = StructType(Array(StructField("d", IntegerType, false),
+        StructField("i", it, false)))
+      withTempPath { file =>
+        val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+        df.write.csv(file.getCanonicalPath)
+        val df2 = spark.read.csv(file.getCanonicalPath)
+        checkAnswer(df2, df.select($"d".cast(StringType), $"i".cast(StringType)).collect().toSeq)
+        val df3 = spark.read.schema(schema).csv(file.getCanonicalPath)
+        checkAnswer(df3, df.collect().toSeq)
+      }
+    }
+  }
 }
 
 class CSVv1Suite extends CSVSuite {
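Day-time intervals follow the same pattern as the year-month case, mirroring the new CSVSuite test above (again a sketch with illustrative path and column name; the exact written literal may vary with the interval's field precision):

```scala
import java.time.Duration
import spark.implicits._

// A java.time.Duration column is encoded as DayTimeIntervalType (INTERVAL DAY TO SECOND).
val df = Seq(Duration.ofDays(1).plusSeconds(1)).toDF("i")

// Written as an ANSI literal such as INTERVAL '1 00:00:01' DAY TO SECOND.
df.write.csv("/tmp/daytime-interval-csv")

// An explicit schema is required to read the column back as an interval rather than a string.
val back = spark.read.schema("i INTERVAL DAY TO SECOND").csv("/tmp/daytime-interval-csv")
back.show(false)
```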
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org