This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f853afd [SPARK-36931][SQL] Support reading and writing ANSI intervals from/to ORC datasources
f853afd is described below
commit f853afdc035273f772dc47f5476be6cf205d0941
Author: Kousuke Saruta <[email protected]>
AuthorDate: Fri Oct 8 10:49:11 2021 +0300
[SPARK-36931][SQL] Support reading and writing ANSI intervals from/to ORC datasources
### What changes were proposed in this pull request?
This PR aims to support reading and writing ANSI intervals from/to ORC datasources.
Year-month and day-time intervals are mapped to ORC's `int` and `bigint` types respectively.
To preserve the Catalyst types, this change adds the `spark.sql.catalyst.type` attribute to the corresponding ORC type descriptions. The value of the attribute is the value returned by `YearMonthIntervalType.typeName` or `DayTimeIntervalType.typeName`.
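For illustration only (not part of the patch), a minimal sketch of the tagging scheme, assuming ORC's `TypeDescription` attribute API (`setAttribute`/`getAttributeValue`, available since ORC 1.6):
```scala
import org.apache.orc.TypeDescription

// Write path: a year-month interval column is stored as an ORC int
// and tagged with the original Catalyst type name.
val orcType = TypeDescription.createInt()
orcType.setAttribute("spark.sql.catalyst.type", "interval year to month")

// Read path: when the attribute is present it takes precedence over the raw ORC type,
// so the column is restored as YearMonthIntervalType instead of IntegerType.
val catalystTypeName = orcType.getAttributeValue("spark.sql.catalyst.type")
// catalystTypeName == "interval year to month"
```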
### Why are the changes needed?
For better usability. There should be no reason to prohibit reading/writing ANSI intervals from/to ORC datasources.
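As a hedged sketch of the effect (the path and column names below are made up for illustration), a round trip that previously failed for ORC now works, e.g. in `spark-shell`:
```scala
import java.time.{Duration, Period}

// Period maps to YearMonthIntervalType, Duration to DayTimeIntervalType.
val df = Seq(
  (Period.ofYears(1).plusMonths(2), Duration.ofDays(1).plusMinutes(30))
).toDF("ym", "dt")

df.write.orc("/tmp/ansi_intervals")          // used to fail with AnalysisException
val back = spark.read.orc("/tmp/ansi_intervals")
back.printSchema()                           // ym/dt come back as interval types
```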
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests.
Closes #34184 from sarutak/ansi-interval-orc-source.
Authored-by: Kousuke Saruta <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../sql/execution/datasources/DataSource.scala | 3 +-
.../datasources/orc/OrcDeserializer.scala | 10 ++--
.../execution/datasources/orc/OrcFileFormat.scala | 8 ++-
.../datasources/orc/OrcOutputWriter.scala | 1 +
.../execution/datasources/orc/OrcSerializer.scala | 4 +-
.../sql/execution/datasources/orc/OrcUtils.scala | 62 +++++++++++++++++++++-
.../execution/datasources/v2/orc/OrcTable.scala | 2 -
.../datasources/CommonFileDataSourceSuite.scala | 2 +-
.../execution/datasources/orc/OrcSourceSuite.scala | 49 ++++++++++++++++-
9 files changed, 123 insertions(+), 18 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 32913c6..9936126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -581,7 +581,8 @@ case class DataSource(
// TODO: Remove the Set below once all the built-in datasources support ANSI interval types
private val writeAllowedSources: Set[Class[_]] =
- Set(classOf[ParquetFileFormat], classOf[CSVFileFormat], classOf[JsonFileFormat])
+ Set(classOf[ParquetFileFormat], classOf[CSVFileFormat],
+ classOf[JsonFileFormat], classOf[OrcFileFormat])
private def disallowWritingIntervals(
dataTypes: Seq[DataType],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
index fa8977f..1476083 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -86,10 +86,10 @@ class OrcDeserializer(
case ShortType => (ordinal, value) =>
updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get)
- case IntegerType => (ordinal, value) =>
+ case IntegerType | _: YearMonthIntervalType => (ordinal, value) =>
updater.setInt(ordinal, value.asInstanceOf[IntWritable].get)
- case LongType => (ordinal, value) =>
+ case LongType | _: DayTimeIntervalType => (ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[LongWritable].get)
case FloatType => (ordinal, value) =>
@@ -197,8 +197,10 @@ class OrcDeserializer(
case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
- case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
- case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case IntegerType | _: YearMonthIntervalType =>
+ UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType | _: DayTimeIntervalType =>
+ UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
case _ => new GenericArrayData(new Array[Any](length))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index c4ffdb4..26af2c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc.{OrcUtils => _, _}
-import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
+import org.apache.orc.OrcConf.COMPRESS
import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce._
@@ -45,6 +45,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
private[sql] object OrcFileFormat {
def getQuotedSchemaString(dataType: DataType): String = dataType match {
+ case _: DayTimeIntervalType => LongType.catalogString
+ case _: YearMonthIntervalType => IntegerType.catalogString
case _: AtomicType => dataType.catalogString
case StructType(fields) =>
fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}")
@@ -90,8 +92,6 @@ class OrcFileFormat
val conf = job.getConfiguration
- conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema))
-
conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
conf.asInstanceOf[JobConf]
@@ -233,8 +233,6 @@ class OrcFileFormat
}
override def supportDataType(dataType: DataType): Boolean = dataType match {
- case _: AnsiIntervalType => false
-
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
index 6f21573..fe057e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -44,6 +44,7 @@ private[sql] class OrcOutputWriter(
}
val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc")
val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration)
+ options.setSchema(OrcUtils.orcTypeDescription(dataSchema))
val writer = OrcFile.createWriter(filename, options)
val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer)
OrcUtils.addSparkVersionMetadata(writer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
index ac32be2..9a1eb8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -88,7 +88,7 @@ class OrcSerializer(dataSchema: StructType) {
(getter, ordinal) => new ShortWritable(getter.getShort(ordinal))
}
- case IntegerType =>
+ case IntegerType | _: YearMonthIntervalType =>
if (reuseObj) {
val result = new IntWritable()
(getter, ordinal) =>
@@ -99,7 +99,7 @@ class OrcSerializer(dataSchema: StructType) {
}
- case LongType =>
+ case LongType | _: DayTimeIntervalType =>
if (reuseObj) {
val result = new LongWritable()
(getter, ordinal) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index ec57375..475448a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -50,6 +50,8 @@ object OrcUtils extends Logging {
"LZ4" -> ".lz4",
"LZO" -> ".lzo")
+ val CATALYST_TYPE_ATTRIBUTE_NAME = "spark.sql.catalyst.type"
+
def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
val origPath = new Path(pathStr)
val fs = origPath.getFileSystem(conf)
@@ -93,7 +95,13 @@ object OrcUtils extends Logging {
case Category.STRUCT => toStructType(orcType)
case Category.LIST => toArrayType(orcType)
case Category.MAP => toMapType(orcType)
- case _ => CatalystSqlParser.parseDataType(orcType.toString)
+ case _ =>
+ val catalystTypeAttrValue = orcType.getAttributeValue(CATALYST_TYPE_ATTRIBUTE_NAME)
+ if (catalystTypeAttrValue != null) {
+ CatalystSqlParser.parseDataType(catalystTypeAttrValue)
+ } else {
+ CatalystSqlParser.parseDataType(orcType.toString)
+ }
}
}
@@ -265,9 +273,61 @@ object OrcUtils extends Logging {
s"array<${orcTypeDescriptionString(a.elementType)}>"
case m: MapType =>
s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>"
+ case _: DayTimeIntervalType => LongType.catalogString
+ case _: YearMonthIntervalType => IntegerType.catalogString
case _ => dt.catalogString
}
+ def orcTypeDescription(dt: DataType): TypeDescription = {
+ def getInnerTypeDecription(dt: DataType): Option[TypeDescription] = {
+ dt match {
+ case y: YearMonthIntervalType =>
+ val typeDesc = orcTypeDescription(IntegerType)
+ typeDesc.setAttribute(
+ CATALYST_TYPE_ATTRIBUTE_NAME, y.typeName)
+ Some(typeDesc)
+ case d: DayTimeIntervalType =>
+ val typeDesc = orcTypeDescription(LongType)
+ typeDesc.setAttribute(
+ CATALYST_TYPE_ATTRIBUTE_NAME, d.typeName)
+ Some(typeDesc)
+ case _ => None
+ }
+ }
+
+ dt match {
+ case s: StructType =>
+ val result = new TypeDescription(TypeDescription.Category.STRUCT)
+ s.fields.foreach { f =>
+ getInnerTypeDecription(f.dataType) match {
+ case Some(t) => result.addField(f.name, t)
+ case None => result.addField(f.name, orcTypeDescription(f.dataType))
+ }
+ }
+ result
+ case a: ArrayType =>
+ val result = new TypeDescription(TypeDescription.Category.LIST)
+ getInnerTypeDecription(a.elementType) match {
+ case Some(t) => result.addChild(t)
+ case None => result.addChild(orcTypeDescription(a.elementType))
+ }
+ result
+ case m: MapType =>
+ val result = new TypeDescription(TypeDescription.Category.MAP)
+ getInnerTypeDecription(m.keyType) match {
+ case Some(t) => result.addChild(t)
+ case None => result.addChild(orcTypeDescription(m.keyType))
+ }
+ getInnerTypeDecription(m.valueType) match {
+ case Some(t) => result.addChild(t)
+ case None => result.addChild(orcTypeDescription(m.valueType))
+ }
+ result
+ case other =>
+ TypeDescription.fromString(other.catalogString)
+ }
+ }
+
/**
* Returns the result schema to read from ORC file. In addition, It sets
* the schema string to 'orc.mapred.input.schema' so ORC reader can use later.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index 628b0a1..9cc4525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -49,8 +49,6 @@ case class OrcTable(
}
override def supportsDataType(dataType: DataType): Boolean = dataType match {
- case _: AnsiIntervalType => false
-
case _: AtomicType => true
case st: StructType => st.forall { f => supportsDataType(f.dataType) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
index 28d0967..854463d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
@@ -36,7 +36,7 @@ trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>
protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)
test(s"SPARK-36349: disallow saving of ANSI intervals to $dataSourceFormat") {
- if (!Set("parquet", "csv", "json").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+ if (!Set("parquet", "csv", "json", "orc").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
withTempPath { dir =>
val errMsg = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index d077814..8ffccd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
import java.io.File
import java.nio.charset.StandardCharsets.UTF_8
import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
import java.util.Locale
import org.apache.hadoop.conf.Configuration
@@ -35,13 +36,14 @@ import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException}
import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtilsBase}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
case class OrcData(intField: Int, stringField: String)
-abstract class OrcSuite extends OrcTest with BeforeAndAfterAll with CommonFileDataSourceSuite {
+abstract class OrcSuite
+ extends OrcTest with BeforeAndAfterAll with CommonFileDataSourceSuite with SQLTestUtilsBase {
import testImplicits._
override protected def dataSourceFormat = "orc"
@@ -806,6 +808,49 @@ abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession {
StructField("456", StringType) :: Nil))))))
}
}
+
+ Seq(true, false).foreach { vecReaderEnabled =>
+ Seq(true, false).foreach { vecReaderNestedColEnabled =>
+ test("SPARK-36931: Support reading and writing ANSI intervals (" +
+ s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " +
+ s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") {
+
+ withSQLConf(
+ SQLConf.ORC_VECTORIZED_READER_ENABLED.key ->
+ vecReaderEnabled.toString,
+ SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key ->
+ vecReaderNestedColEnabled.toString) {
+ Seq(
+ YearMonthIntervalType() -> ((i: Int) => Period.of(i, i, 0)),
+ DayTimeIntervalType() -> ((i: Int) => Duration.ofDays(i).plusSeconds(i))
+ ).foreach { case (it, f) =>
+ val data = (1 to 10).map(i => Row(i, f(i)))
+ val schema = StructType(Array(StructField("d", IntegerType, false),
+ StructField("i", it, false)))
+ withTempPath { file =>
+ val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+ df.write.orc(file.getCanonicalPath)
+ val df2 = spark.read.orc(file.getCanonicalPath)
+ checkAnswer(df2, df.collect().toSeq)
+ }
+ }
+
+ // Tests for ANSI intervals in complex types.
+ withTempPath { file =>
+ val df = spark.sql(
+ """SELECT
+ | named_struct('interval', interval '1-2' year to month) a,
+ | array(interval '1 2:3' day to minute) b,
+ | map('key', interval '10' year) c,
+ | map(interval '20' second, 'value') d""".stripMargin)
+ df.write.orc(file.getCanonicalPath)
+ val df2 = spark.read.orc(file.getCanonicalPath)
+ checkAnswer(df2, df.collect().toSeq)
+ }
+ }
+ }
+ }
+ }
}
class OrcSourceV1Suite extends OrcSourceSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]