This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new ff78343ff9 [GLUTEN-7615][CORE] Introduce `GlutenFormatFactory` (#7616)
ff78343ff9 is described below
commit ff78343ff95e0f16f719c20d3ec1cc50e949d9d2
Author: Chang chen <[email protected]>
AuthorDate: Mon Oct 21 14:09:48 2024 +0800
[GLUTEN-7615][CORE] Introduce `GlutenFormatFactory` (#7616)
---
.../delta/ClickhouseOptimisticTransaction.scala | 8 +-
.../source/DeltaMergeTreeFileFormat.scala | 13 +--
.../delta/ClickhouseOptimisticTransaction.scala | 8 +-
.../source/DeltaMergeTreeFileFormat.scala | 14 +--
.../delta/ClickhouseOptimisticTransaction.scala | 8 +-
.../source/DeltaMergeTreeFileFormat.scala | 13 +--
.../backendsapi/clickhouse/CHListenerApi.scala | 15 +--
.../v1/GlutenMergeTreeWriterInjects.scala | 33 -------
.../clickhouse/MergeTreeFileFormatDataWriter.scala | 4 +-
.../v1/clickhouse/MergeTreeFileFormatWriter.scala | 5 +-
...lutenClickHouseWholeStageTransformerSuite.scala | 3 +-
.../execution/tpch/GlutenClickHouseHDFSSuite.scala | 2 +-
.../backendsapi/velox/VeloxListenerApi.scala | 10 +-
.../GlutenFormatWriterInjectsBase.scala | 7 --
.../datasources/GlutenWriterColumnarRules.scala | 110 +++++++++------------
.../datasource/GlutenFormatWriterInjects.scala | 43 ++++++--
.../datasource/GlutenOrcWriterInjects.scala | 32 ------
.../datasource/GlutenParquetWriterInjects.scala | 31 ------
.../gluten/sql/shims/spark32/Spark32Shims.scala | 6 +-
.../datasources/FileFormatDataWriter.scala | 4 +-
.../execution/datasources/FileFormatWriter.scala | 11 +--
.../execution/datasources/orc/OrcFileFormat.scala | 13 +--
.../datasources/parquet/ParquetFileFormat.scala | 16 ++-
.../spark/sql/hive/execution/HiveFileFormat.scala | 22 +----
.../gluten/sql/shims/spark33/Spark33Shims.scala | 8 +-
.../datasources/FileFormatDataWriter.scala | 4 +-
.../execution/datasources/FileFormatWriter.scala | 11 +--
.../execution/datasources/orc/OrcFileFormat.scala | 14 +--
.../datasources/parquet/ParquetFileFormat.scala | 16 ++-
.../spark/sql/hive/execution/HiveFileFormat.scala | 22 +----
30 files changed, 187 insertions(+), 319 deletions(-)
diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
index 773cd35e93..461c088be9 100644
--- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
+++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.delta.schema.InvariantViolationException
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteJobStatsTracker}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, GlutenWriterColumnarRules, WriteJobStatsTracker}
import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeFileFormatWriter
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.util.{Clock, SerializableConfiguration}
@@ -137,10 +137,12 @@ class ClickhouseOptimisticTransaction(
try {
val tableV2 = ClickHouseTableV2.getTable(deltaLog)
+ val format = tableV2.getFileFormat(metadata)
+ GlutenWriterColumnarRules.injectSparkLocalProperty(spark, Some(format.shortName()))
MergeTreeFileFormatWriter.write(
sparkSession = spark,
plan = newQueryPlan,
- fileFormat = tableV2.getFileFormat(metadata),
+ fileFormat = format,
// formats.
committer = committer,
outputSpec = outputSpec,
@@ -169,6 +171,8 @@ class ClickhouseOptimisticTransaction(
} else {
throw s
}
+ } finally {
+ GlutenWriterColumnarRules.injectSparkLocalProperty(spark, None)
}
}
committer.addedStatuses.toSeq ++ committer.changeFiles
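
The property injection above is paired with a reset in the finally block so a
failed MergeTree write cannot leak thread-local state into the next query. A
minimal sketch of the pattern (assuming a SparkSession `spark` and a resolved
file format `format`, as in the hunk above):

    // Set the native-write properties for this thread, then guarantee cleanup.
    GlutenWriterColumnarRules.injectSparkLocalProperty(spark, Some(format.shortName()))
    try {
      // ... run MergeTreeFileFormatWriter.write(...) ...
    } finally {
      // Reset even on failure so the thread-local properties cannot leak
      // into the next query submitted on the same thread.
      GlutenWriterColumnarRules.injectSparkLocalProperty(spark, None)
    }
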
diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
index 5b2dc164b5..94891d0dd4 100644
--- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
+++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
@@ -16,12 +16,13 @@
*/
package org.apache.spark.sql.execution.datasources.v2.clickhouse.source
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
+
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.DeltaParquetFileFormat
import org.apache.spark.sql.delta.actions.Metadata
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.execution.datasources.mergetree.DeltaMetaReader
-import org.apache.spark.sql.execution.datasources.v1.GlutenMergeTreeWriterInjects
import org.apache.spark.sql.types.StructType
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -31,7 +32,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
override def shortName(): String = "mergetree"
- override def toString(): String = "MergeTree"
+ override def toString: String = "MergeTree"
override def equals(other: Any): Boolean = {
other match {
@@ -51,10 +52,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
// pass compression to job conf so that the file extension can be aware of it.
val conf = job.getConfiguration
- val nativeConf =
- GlutenMergeTreeWriterInjects
- .getInstance()
- .nativeConf(options, "")
+ val nativeConf = GlutenFormatFactory(shortName()).nativeConf(options, "")
@transient val deltaMetaReader = DeltaMetaReader(metadata)
deltaMetaReader.storageConf.foreach { case (k, v) => conf.set(k, v) }
@@ -69,8 +67,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenMergeTreeWriterInjects
- .getInstance()
+ GlutenFormatFactory(shortName())
.createOutputWriter(path, metadata.schema, context, nativeConf)
}
}
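
With this change the MergeTree format resolves its writer injects through the
new registry instead of a dedicated singleton; the lookup is keyed by the
format's short name. Roughly (a sketch, assuming the backend has already
registered an injects instance whose formatName is "mergetree"):

    // Resolve the registered injects for this format by its short name.
    val injects = GlutenFormatFactory(shortName()) // shortName() == "mergetree"
    val nativeConf = injects.nativeConf(options, "")
    // Per task, the same instance builds the native output writer:
    // injects.createOutputWriter(path, metadata.schema, context, nativeConf)
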
diff --git a/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala b/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
index 773cd35e93..461c088be9 100644
--- a/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
+++ b/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.delta.schema.InvariantViolationException
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteJobStatsTracker}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, GlutenWriterColumnarRules, WriteJobStatsTracker}
import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeFileFormatWriter
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.util.{Clock, SerializableConfiguration}
@@ -137,10 +137,12 @@ class ClickhouseOptimisticTransaction(
try {
val tableV2 = ClickHouseTableV2.getTable(deltaLog)
+ val format = tableV2.getFileFormat(metadata)
+ GlutenWriterColumnarRules.injectSparkLocalProperty(spark, Some(format.shortName()))
MergeTreeFileFormatWriter.write(
sparkSession = spark,
plan = newQueryPlan,
- fileFormat = tableV2.getFileFormat(metadata),
+ fileFormat = format,
// formats.
committer = committer,
outputSpec = outputSpec,
@@ -169,6 +171,8 @@ class ClickhouseOptimisticTransaction(
} else {
throw s
}
+ } finally {
+ GlutenWriterColumnarRules.injectSparkLocalProperty(spark, None)
}
}
committer.addedStatuses.toSeq ++ committer.changeFiles
diff --git a/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala b/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
index 24dbc6e03b..6dac603e05 100644
--- a/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
+++ b/backends-clickhouse/src/main/delta-23/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
@@ -16,12 +16,13 @@
*/
package org.apache.spark.sql.execution.datasources.v2.clickhouse.source
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
+
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.DeltaParquetFileFormat
import org.apache.spark.sql.delta.actions.Metadata
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.execution.datasources.mergetree.DeltaMetaReader
-import org.apache.spark.sql.execution.datasources.v1.GlutenMergeTreeWriterInjects
import org.apache.spark.sql.types.StructType
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -31,7 +32,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma
override def shortName(): String = "mergetree"
- override def toString(): String = "MergeTree"
+ override def toString: String = "MergeTree"
override def equals(other: Any): Boolean = {
other match {
@@ -54,11 +55,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma
// pass compression to job conf so that the file extension can be aware of it.
val conf = job.getConfiguration
- // just for the sake of compatibility
- val nativeConf =
- GlutenMergeTreeWriterInjects
- .getInstance()
- .nativeConf(options, "")
+ val nativeConf = GlutenFormatFactory(shortName()).nativeConf(options, "")
@transient val deltaMetaReader = DeltaMetaReader(metadata)
deltaMetaReader.storageConf.foreach { case (k, v) => conf.set(k, v) }
@@ -73,8 +70,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenMergeTreeWriterInjects
- .getInstance()
+ GlutenFormatFactory(shortName())
.createOutputWriter(path, metadata.schema, context, nativeConf)
}
}
diff --git a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
index 00940a4851..aa1a5006ee 100644
--- a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
+++ b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.stats.DeltaJobStatisticsTracker
import org.apache.spark.sql.execution.{CHDelayedCommitProtocol, QueryExecution, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteFiles, WriteJobStatsTracker}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, GlutenWriterColumnarRules, WriteFiles, WriteJobStatsTracker}
import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeFileFormatWriter
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.internal.SQLConf
@@ -149,10 +149,12 @@ class ClickhouseOptimisticTransaction(
try {
val tableV2 = ClickHouseTableV2.getTable(deltaLog)
+ val format = tableV2.getFileFormat(protocol, metadata)
+ GlutenWriterColumnarRules.injectSparkLocalProperty(spark, Some(format.shortName()))
MergeTreeFileFormatWriter.write(
sparkSession = spark,
plan = newQueryPlan,
- fileFormat = tableV2.getFileFormat(protocol, metadata),
+ fileFormat = format,
// formats.
committer = committer,
outputSpec = outputSpec,
@@ -181,6 +183,8 @@ class ClickhouseOptimisticTransaction(
} else {
throw s
}
+ } finally {
+ GlutenWriterColumnarRules.injectSparkLocalProperty(spark, None)
}
}
committer.addedStatuses.toSeq ++ committer.changeFiles
diff --git a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
index 6cc431f4f9..3c4d2eb2e2 100644
--- a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
+++ b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
@@ -16,12 +16,13 @@
*/
package org.apache.spark.sql.execution.datasources.v2.clickhouse.source
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
+
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.DeltaParquetFileFormat
import org.apache.spark.sql.delta.actions.{Metadata, Protocol}
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.execution.datasources.mergetree.DeltaMetaReader
-import org.apache.spark.sql.execution.datasources.v1.GlutenMergeTreeWriterInjects
import org.apache.spark.sql.types.StructType
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -32,7 +33,7 @@ class DeltaMergeTreeFileFormat(protocol: Protocol, metadata: Metadata)
override def shortName(): String = "mergetree"
- override def toString(): String = "MergeTree"
+ override def toString: String = "MergeTree"
override def equals(other: Any): Boolean = {
other match {
@@ -54,10 +55,7 @@ class DeltaMergeTreeFileFormat(protocol: Protocol, metadata: Metadata)
// pass compression to job conf so that the file extension can be aware of it.
val conf = job.getConfiguration
- val nativeConf =
- GlutenMergeTreeWriterInjects
- .getInstance()
- .nativeConf(options, "")
+ val nativeConf = GlutenFormatFactory(shortName()).nativeConf(options, "")
@transient val deltaMetaReader = DeltaMetaReader(metadata)
deltaMetaReader.storageConf.foreach { case (k, v) => conf.set(k, v) }
@@ -72,8 +70,7 @@ class DeltaMergeTreeFileFormat(protocol: Protocol, metadata: Metadata)
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenMergeTreeWriterInjects
- .getInstance()
+ GlutenFormatFactory(shortName())
.createOutputWriter(path, metadata.schema, context, nativeConf)
}
}
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 16f5fa064c..6ae957912a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.ListenerApi
import org.apache.gluten.columnarbatch.CHBatch
import org.apache.gluten.execution.CHBroadcastBuildSideCache
-import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.jni.JniLibLoader
@@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.listener.CHGlutenSQLAppStatusListener
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint}
+import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules
import org.apache.spark.sql.execution.datasources.v1._
import org.apache.spark.sql.utils.ExpressionUtil
import org.apache.spark.util.SparkDirectoryUtil
@@ -109,11 +110,13 @@ class CHListenerApi extends ListenerApi with Logging {
CHNativeExpressionEvaluator.initNative(conf.getAll.toMap)
// inject backend-specific implementations to override spark classes
- // FIXME: The following set instances twice in local mode?
- GlutenParquetWriterInjects.setInstance(new CHParquetWriterInjects())
- GlutenOrcWriterInjects.setInstance(new CHOrcWriterInjects())
- GlutenMergeTreeWriterInjects.setInstance(new CHMergeTreeWriterInjects())
- GlutenRowSplitter.setInstance(new CHRowSplitter())
+ GlutenFormatFactory.register(
+ new CHParquetWriterInjects,
+ new CHOrcWriterInjects,
+ new CHMergeTreeWriterInjects)
+ GlutenFormatFactory.injectPostRuleFactory(
+ session => GlutenWriterColumnarRules.NativeWritePostRule(session))
+ GlutenFormatFactory.register(new CHRowSplitter())
}
private def shutdown(): Unit = {
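
After this registration the factory can be queried by format name anywhere in
the write path; a sketch of what the ClickHouse setup above establishes
(assuming the three injects report the short names "parquet", "orc" and
"mergetree" via formatName):

    // All three formats are registered in one call, keyed by formatName.
    assert(GlutenFormatFactory.isRegistered("parquet"))
    assert(GlutenFormatFactory.isRegistered("mergetree"))
    val orcInjects = GlutenFormatFactory("orc") // the CHOrcWriterInjects instance
    // Looking up an unregistered name fails fast with IllegalStateException.
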
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/GlutenMergeTreeWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/GlutenMergeTreeWriterInjects.scala
deleted file mode 100644
index 36d8481b1a..0000000000
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/GlutenMergeTreeWriterInjects.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v1
-
-import org.apache.gluten.execution.datasource.GlutenFormatWriterInjects
-
-object GlutenMergeTreeWriterInjects {
- private var INSTANCE: GlutenFormatWriterInjects = _
-
- def setInstance(instance: GlutenFormatWriterInjects): Unit = {
- INSTANCE = instance
- }
- def getInstance(): GlutenFormatWriterInjects = {
- if (INSTANCE == null) {
- throw new IllegalStateException("GlutenOutputWriterFactoryCreator is not
initialized")
- }
- INSTANCE
- }
-}
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatDataWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatDataWriter.scala
index 29f2b7e16e..8be183f051 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatDataWriter.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatDataWriter.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.v1.clickhouse
-import org.apache.gluten.execution.datasource.GlutenRowSplitter
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
@@ -437,7 +437,7 @@ class MergeTreeDynamicPartitionDataSingleWriter(
record match {
case fakeRow: FakeRow =>
if (fakeRow.batch.numRows() > 0) {
- val blockStripes = GlutenRowSplitter.getInstance
+ val blockStripes = GlutenFormatFactory.rowSplitter
.splitBlockByPartitionAndBucket(
fakeRow,
partitionColIndice,
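
GlutenFormatFactory.rowSplitter replaces the old GlutenRowSplitter.getInstance
accessor; the call sites are otherwise unchanged, and the accessor still throws
IllegalStateException if no backend has registered a splitter. As in the Spark
shims' DynamicPartitionDataSingleWriter:

    // Split a columnar FakeRow batch by partition and bucket via the registry.
    val blockStripes = GlutenFormatFactory.rowSplitter
      .splitBlockByPartitionAndBucket(fakeRow, partitionColIndice, isBucketed)
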
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala
index f1489b86b3..7c3da47e66 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.v1.clickhouse
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.memory.CHThreadGroup
import org.apache.spark.{SparkException, TaskContext, TaskOutputFileAlreadyExistException}
@@ -33,7 +34,6 @@ import org.apache.spark.sql.delta.constraints.Constraint
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.FileFormatWriter.{processStats, ConcurrentOutputWriterSpec, OutputSpec}
-import org.apache.spark.sql.execution.datasources.v1.GlutenMergeTreeWriterInjects
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -173,7 +173,8 @@ object MergeTreeFileFormatWriter extends Logging {
// TODO: to optimize, bucket value is computed twice here
}
- (GlutenMergeTreeWriterInjects.getInstance().executeWriterWrappedSparkPlan(wrapped), None)
+ val nativeFormat = sparkSession.sparkContext.getLocalProperty("nativeFormat")
+ (GlutenFormatFactory(nativeFormat).executeWriterWrappedSparkPlan(wrapped), None)
}
try {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala
index e5c4d14a34..bdf5a3a5ae 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala
@@ -45,7 +45,8 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSu
val S3_CACHE_PATH = s"/tmp/s3_cache/$sparkVersion/"
val S3_ENDPOINT = "s3://127.0.0.1:9000/"
val MINIO_ENDPOINT: String = S3_ENDPOINT.replace("s3", "http")
- val BUCKET_NAME: String = sparkVersion.replace(".", "-")
+ val SPARK_DIR_NAME: String = sparkVersion.replace(".", "-")
+ val BUCKET_NAME: String = SPARK_DIR_NAME
val WHOLE_PATH: String = MINIO_ENDPOINT + BUCKET_NAME + "/"
val HDFS_METADATA_PATH = s"/tmp/metadata/hdfs/$sparkVersion/"
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseHDFSSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseHDFSSuite.scala
index 7a1bba0f1b..b1fbd3e736 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseHDFSSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseHDFSSuite.scala
@@ -179,7 +179,7 @@ class GlutenClickHouseHDFSSuite
test("test set_read_util_position") {
val tableName = "read_until_test"
- val tablePath = s"$tablesPath/$tableName/"
+ val tablePath = s"$tablesPath/$SPARK_DIR_NAME/$tableName/"
val targetFile = new Path(tablesPath)
val fs = targetFile.getFileSystem(spark.sessionState.newHadoopConf())
fs.delete(new Path(tablePath), true)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index e763e31dc5..850509db3e 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.ListenerApi
import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch, ArrowNativeBatch}
import org.apache.gluten.columnarbatch.VeloxBatch
-import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.init.NativeBackendInitializer
import org.apache.gluten.jni.{JniLibLoader, JniWorkspace}
@@ -31,6 +31,7 @@ import org.apache.spark.{HdfsConfGenerator, SparkConf, SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules
import org.apache.spark.sql.execution.datasources.velox.{VeloxOrcWriterInjects, VeloxParquetWriterInjects, VeloxRowSplitter}
import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.internal.{GlutenConfigUtil, StaticSQLConf}
@@ -160,9 +161,10 @@ class VeloxListenerApi extends ListenerApi with Logging {
NativeBackendInitializer.initializeBackend(parsed)
// Inject backend-specific implementations to override spark classes.
- GlutenParquetWriterInjects.setInstance(new VeloxParquetWriterInjects())
- GlutenOrcWriterInjects.setInstance(new VeloxOrcWriterInjects())
- GlutenRowSplitter.setInstance(new VeloxRowSplitter())
+ GlutenFormatFactory.register(new VeloxParquetWriterInjects, new VeloxOrcWriterInjects)
+ GlutenFormatFactory.injectPostRuleFactory(
+ session => GlutenWriterColumnarRules.NativeWritePostRule(session))
+ GlutenFormatFactory.register(new VeloxRowSplitter())
}
private def shutdown(): Unit = {
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
index 450b88163a..9ec75aa209 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
@@ -23,12 +23,9 @@ import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverri
import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, SparkPlan}
import org.apache.spark.sql.execution.ColumnarCollapseTransformStages.transformStageCounter
-import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule
trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
@@ -75,8 +72,4 @@ trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
transformStageCounter.incrementAndGet())
FakeRowAdaptor(wst).execute()
}
-
- override def getExtendedColumnarPostRule(session: SparkSession): Rule[SparkPlan] = {
- NativeWritePostRule(session)
- }
}
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala
index 20b0060153..d33e779eb3 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.ColumnarToRowExecBase
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.transition.Transitions
@@ -30,10 +31,9 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.{AppendDataExec, OverwriteByExpressionExec}
import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveDirCommand, InsertIntoHiveTable}
+import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.vectorized.ColumnarBatch
private case class FakeRowLogicAdaptor(child: LogicalPlan) extends OrderPreservingUnaryNode {
@@ -92,56 +92,35 @@ object GlutenWriterColumnarRules {
// 1. pull out `Empty2Null` and required ordering to `WriteFilesExec`, see Spark3.4 `V1Writes`
// 2. support detect partition value, partition path, bucket value, bucket path at native side,
// see `BaseDynamicPartitionDataWriter`
- def getNativeFormat(cmd: DataWritingCommand): Option[String] = {
- val parquetHiveFormat = "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"
- val orcHiveFormat = "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"
-
+ private val formatMapping = Map(
+ "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat" -> "orc",
+ "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat" -> "parquet"
+ )
+ private def getNativeFormat(cmd: DataWritingCommand): Option[String] = {
if (!BackendsApiManager.getSettings.enableNativeWriteFiles()) {
return None
}
cmd match {
- case command: CreateDataSourceTableAsSelectCommand =>
- if (BackendsApiManager.getSettings.skipNativeCtas(command)) {
- return None
- }
- if ("parquet".equals(command.table.provider.get)) {
- Some("parquet")
- } else if ("orc".equals(command.table.provider.get)) {
- Some("orc")
- } else {
- None
- }
+ case command: CreateDataSourceTableAsSelectCommand
+ if !BackendsApiManager.getSettings.skipNativeCtas(command) =>
+ command.table.provider.filter(GlutenFormatFactory.isRegistered)
case command: InsertIntoHadoopFsRelationCommand
- if command.fileFormat.isInstanceOf[ParquetFileFormat] ||
- command.fileFormat.isInstanceOf[OrcFileFormat] =>
- if (BackendsApiManager.getSettings.skipNativeInsertInto(command)) {
- return None
- }
-
- if (command.fileFormat.isInstanceOf[ParquetFileFormat]) {
- Some("parquet")
- } else if (command.fileFormat.isInstanceOf[OrcFileFormat]) {
- Some("orc")
- } else {
- None
+ if !BackendsApiManager.getSettings.skipNativeInsertInto(command) =>
+ command.fileFormat match {
+ case register: DataSourceRegister
+ if GlutenFormatFactory.isRegistered(register.shortName()) =>
+ Some(register.shortName())
+ case _ => None
}
case command: InsertIntoHiveDirCommand =>
- if (command.storage.outputFormat.get.equals(parquetHiveFormat)) {
- Some("parquet")
- } else if (command.storage.outputFormat.get.equals(orcHiveFormat)) {
- Some("orc")
- } else {
- None
- }
+ command.storage.outputFormat
+ .flatMap(formatMapping.get)
+ .filter(GlutenFormatFactory.isRegistered)
case command: InsertIntoHiveTable =>
- if (command.table.storage.outputFormat.get.equals(parquetHiveFormat)) {
- Some("parquet")
- } else if (command.table.storage.outputFormat.get.equals(orcHiveFormat)) {
- Some("orc")
- } else {
- None
- }
+ command.table.storage.outputFormat
+ .flatMap(formatMapping.get)
+ .filter(GlutenFormatFactory.isRegistered)
case _: CreateHiveTableAsSelectCommand =>
None
case _ =>
@@ -163,27 +142,22 @@ object GlutenWriterColumnarRules {
BackendsApiManager.getSettings.enableNativeWriteFiles() =>
injectFakeRowAdaptor(rc, rc.child)
case rc @ DataWritingCommandExec(cmd, child) =>
- // These properties can be set by the same thread in last query submission.
- session.sparkContext.setLocalProperty("isNativeApplicable", null)
- session.sparkContext.setLocalProperty("nativeFormat", null)
- session.sparkContext.setLocalProperty("staticPartitionWriteOnly", null)
- if (BackendsApiManager.getSettings.supportNativeWrite(child.output.toStructType.fields)) {
- val format = getNativeFormat(cmd)
- session.sparkContext.setLocalProperty(
- "staticPartitionWriteOnly",
- BackendsApiManager.getSettings.staticPartitionWriteOnly().toString)
- // FIXME: We should only use context property if having no other approaches.
- // Should see if there is another way to pass these options.
- session.sparkContext.setLocalProperty("isNativeApplicable",
format.isDefined.toString)
- session.sparkContext.setLocalProperty("nativeFormat",
format.getOrElse(""))
- if (format.isDefined) {
- injectFakeRowAdaptor(rc, child)
+ // The same thread can set these properties in the last query submission.
+ val fields = child.output.toStructType.fields
+ val format =
+ if (BackendsApiManager.getSettings.supportNativeWrite(fields)) {
+ getNativeFormat(cmd)
} else {
- rc.withNewChildren(rc.children.map(apply))
+ None
}
- } else {
- rc.withNewChildren(rc.children.map(apply))
+ injectSparkLocalProperty(session, format)
+ format match {
+ case Some(_) =>
+ injectFakeRowAdaptor(rc, child)
+ case None =>
+ rc.withNewChildren(rc.children.map(apply))
}
+
case plan: SparkPlan => plan.withNewChildren(plan.children.map(apply))
}
@@ -211,4 +185,18 @@ object GlutenWriterColumnarRules {
}
}
}
+
+ def injectSparkLocalProperty(spark: SparkSession, format: Option[String]): Unit = {
+ if (format.isDefined) {
+ spark.sparkContext.setLocalProperty("isNativeApplicable", true.toString)
+ spark.sparkContext.setLocalProperty("nativeFormat", format.get)
+ spark.sparkContext.setLocalProperty(
+ "staticPartitionWriteOnly",
+ BackendsApiManager.getSettings.staticPartitionWriteOnly().toString)
+ } else {
+ spark.sparkContext.setLocalProperty("isNativeApplicable", null)
+ spark.sparkContext.setLocalProperty("nativeFormat", null)
+ spark.sparkContext.setLocalProperty("staticPartitionWriteOnly", null)
+ }
+ }
}
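
The new injectSparkLocalProperty helper centralizes the three thread-local
properties that were previously set and cleared at several call sites. Its
effect, roughly (a sketch; values are per-thread on the SparkContext):

    // Enable native write for this thread and record the chosen format.
    GlutenWriterColumnarRules.injectSparkLocalProperty(spark, Some("parquet"))
    // getLocalProperty("isNativeApplicable") == "true"
    // getLocalProperty("nativeFormat") == "parquet"
    // getLocalProperty("staticPartitionWriteOnly") == backend setting

    // Passing None clears all three properties (sets them to null).
    GlutenWriterColumnarRules.injectSparkLocalProperty(spark, None)
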
diff --git a/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenFormatWriterInjects.scala b/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenFormatWriterInjects.scala
index 49d86ae380..0221899a3d 100644
--- a/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenFormatWriterInjects.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenFormatWriterInjects.scala
@@ -46,8 +46,6 @@ trait GlutenFormatWriterInjects {
compressionCodec: String): java.util.Map[String, String]
def formatName: String
-
- def getExtendedColumnarPostRule(session: SparkSession): Rule[SparkPlan]
}
trait GlutenRowSplitter {
@@ -58,17 +56,42 @@ trait GlutenRowSplitter {
reserve_partition_columns: Boolean = false): BlockStripes
}
-object GlutenRowSplitter {
- private var INSTANCE: GlutenRowSplitter = _
+object GlutenFormatFactory {
+ private var instances: Map[String, GlutenFormatWriterInjects] = _
+ private var postRuleFactory: SparkSession => Rule[SparkPlan] = _
+ private var rowSplitterInstance: GlutenRowSplitter = _
+
+ def register(items: GlutenFormatWriterInjects*): Unit = {
+ instances = items.map(item => (item.formatName, item)).toMap
+ }
+
+ def isRegistered(name: String): Boolean = instances.contains(name)
+
+ def apply(name: String): GlutenFormatWriterInjects = {
+ instances.getOrElse(
+ name,
+ throw new IllegalStateException(s"GlutenFormatWriterInjects for $name is
not initialized"))
+ }
+
+ def injectPostRuleFactory(factory: SparkSession => Rule[SparkPlan]): Unit = {
+ postRuleFactory = factory
+ }
+
+ def getExtendedColumnarPostRule(session: SparkSession): Rule[SparkPlan] = {
+ if (postRuleFactory == null) {
+ throw new IllegalStateException("GlutenFormatFactory is not initialized")
+ }
+ postRuleFactory(session)
+ }
- def setInstance(instance: GlutenRowSplitter): Unit = {
- INSTANCE = instance
+ def register(rowSplitter: GlutenRowSplitter): Unit = {
+ rowSplitterInstance = rowSplitter
}
- def getInstance(): GlutenRowSplitter = {
- if (INSTANCE == null) {
- throw new IllegalStateException("GlutenOutputWriterFactoryCreator is not
initialized")
+ def rowSplitter: GlutenRowSplitter = {
+ if (rowSplitterInstance == null) {
+ throw new IllegalStateException("GlutenRowSplitter is not initialized")
}
- INSTANCE
+ rowSplitterInstance
}
}
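
For reference, a condensed, self-contained sketch of the registry shape
introduced here (names simplified; not the actual Gluten classes). One subtlety
carried over from the real code: register replaces the whole map rather than
merging, so a backend must pass all of its format injects in a single call:

    trait FormatInjects { def formatName: String }

    object FormatRegistry {
      private var instances: Map[String, FormatInjects] = Map.empty

      // Registers every format at once; a second call replaces the map.
      def register(items: FormatInjects*): Unit =
        instances = items.map(item => item.formatName -> item).toMap

      def isRegistered(name: String): Boolean = instances.contains(name)

      // Fails fast when a format was never registered by the backend.
      def apply(name: String): FormatInjects =
        instances.getOrElse(
          name,
          throw new IllegalStateException(s"FormatInjects for $name is not initialized"))
    }
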
diff --git a/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenOrcWriterInjects.scala b/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenOrcWriterInjects.scala
deleted file mode 100644
index 193cb23a5d..0000000000
--- a/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenOrcWriterInjects.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.execution.datasource
-
-object GlutenOrcWriterInjects {
-
- private var INSTANCE: GlutenFormatWriterInjects = _
-
- def setInstance(instance: GlutenFormatWriterInjects): Unit = {
- INSTANCE = instance
- }
- def getInstance(): GlutenFormatWriterInjects = {
- if (INSTANCE == null) {
- throw new IllegalStateException("GlutenOutputWriterFactoryCreator is not
initialized")
- }
- INSTANCE
- }
-}
diff --git a/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenParquetWriterInjects.scala b/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenParquetWriterInjects.scala
deleted file mode 100644
index ffbec6d89c..0000000000
--- a/shims/common/src/main/scala/org/apache/gluten/execution/datasource/GlutenParquetWriterInjects.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.execution.datasource
-
-object GlutenParquetWriterInjects {
- private var INSTANCE: GlutenFormatWriterInjects = _
-
- def setInstance(instance: GlutenFormatWriterInjects): Unit = {
- INSTANCE = instance
- }
- def getInstance(): GlutenFormatWriterInjects = {
- if (INSTANCE == null) {
- throw new IllegalStateException("GlutenOutputWriterFactoryCreator is not
initialized")
- }
- INSTANCE
- }
-}
diff --git a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
index 973e675fa9..833d2385b0 100644
--- a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.sql.shims.spark32
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.{ExpressionNames, Sig}
import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}
@@ -91,7 +91,7 @@ class Spark32Shims extends SparkShims {
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): TextScan = {
- new TextScan(
+ TextScan(
sparkSession,
fileIndex,
readDataSchema,
@@ -154,7 +154,7 @@ class Spark32Shims extends SparkShims {
mightContainReplacer: (Expression, Expression) => BinaryExpression):
Expression = expr
override def getExtendedColumnarPostRules(): List[SparkSession =>
Rule[SparkPlan]] = {
- List(session =>
GlutenParquetWriterInjects.getInstance().getExtendedColumnarPostRule(session))
+ List(session => GlutenFormatFactory.getExtendedColumnarPostRule(session))
}
override def createTestTaskContext(properties: Properties): TaskContext = {
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index e5aaff6911..a0dafa25db 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -23,7 +23,7 @@ package org.apache.spark.sql.execution.datasources
* we can move this class to shims-spark32,
* shims-spark33, etc.
*/
-import org.apache.gluten.execution.datasource.GlutenRowSplitter
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
@@ -408,7 +408,7 @@ class DynamicPartitionDataSingleWriter(
record match {
case fakeRow: FakeRow =>
if (fakeRow.batch.numRows() > 0) {
- val blockStripes = GlutenRowSplitter.getInstance
+ val blockStripes = GlutenFormatFactory.rowSplitter
.splitBlockByPartitionAndBucket(fakeRow, partitionColIndice, isBucketed)
val iter = blockStripes.iterator()
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 96a044c0cb..e9fef1767e 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -16,8 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources
-import org.apache.gluten.execution.datasource.GlutenOrcWriterInjects
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -59,7 +58,7 @@ import java.util.{Date, UUID}
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
- var executeWriterWrappedSparkPlan: SparkPlan => RDD[InternalRow] = null
+ var executeWriterWrappedSparkPlan: SparkPlan => RDD[InternalRow] = _
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
@@ -257,11 +256,7 @@ object FileFormatWriter extends Logging {
}
val nativeFormat = sparkSession.sparkContext.getLocalProperty("nativeFormat")
- if ("parquet" == nativeFormat) {
- (GlutenParquetWriterInjects.getInstance().executeWriterWrappedSparkPlan(wrapped), None)
- } else {
- (GlutenOrcWriterInjects.getInstance().executeWriterWrappedSparkPlan(wrapped), None)
- }
+ (GlutenFormatFactory(nativeFormat).executeWriterWrappedSparkPlan(wrapped), None)
}
try {
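
The writer no longer hard-codes a parquet/orc branch; it dispatches on whatever
format name the columnar rule recorded earlier through injectSparkLocalProperty.
A sketch of the dispatch (assuming the "nativeFormat" property was set on this
thread before the write was planned):

    // Look up the thread-local format name recorded by GlutenWriterColumnarRules.
    val nativeFormat = sparkSession.sparkContext.getLocalProperty("nativeFormat")
    // e.g. "parquet" or "orc"; an unregistered or null name fails fast
    // in GlutenFormatFactory.apply with IllegalStateException.
    (GlutenFormatFactory(nativeFormat).executeWriterWrappedSparkPlan(wrapped), None)
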
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 8ed7614ae0..1d1771ccd1 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.orc
-import org.apache.gluten.execution.datasource.GlutenOrcWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
@@ -104,9 +104,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
if ("true" == sparkSession.sparkContext.getLocalProperty("isNativeApplicable")) {
// pass compression to job conf so that the file extension can be aware of it.
val nativeConf =
- GlutenOrcWriterInjects
- .getInstance()
- .nativeConf(options, orcOptions.compressionCodec)
+ GlutenFormatFactory(shortName()).nativeConf(options, orcOptions.compressionCodec)
new OutputWriterFactory {
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -122,9 +120,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenOrcWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
+ GlutenFormatFactory(shortName())
+ .createOutputWriter(path, dataSchema, context, nativeConf)
}
}
@@ -188,7 +185,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
- hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] =
{
+ hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
val resultSchema = StructType(requiredSchema.fields ++
partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 145c36e467..eec4c86fcb 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.parquet
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -89,8 +89,7 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
val nativeConf =
- GlutenParquetWriterInjects
- .getInstance()
+ GlutenFormatFactory(shortName())
.nativeConf(options, parquetOptions.compressionCodecClassName)
new OutputWriterFactory {
@@ -102,9 +101,8 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenParquetWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
+ GlutenFormatFactory(shortName())
+ .createOutputWriter(path, dataSchema, context, nativeConf)
}
}
@@ -238,7 +236,7 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
- hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, requiredSchema.json)
hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, requiredSchema.json)
@@ -331,7 +329,7 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
// have different writers.
// Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads.
def isCreatedByParquetMr: Boolean =
- footerFileMetaData.getCreatedBy().startsWith("parquet-mr")
+ footerFileMetaData.getCreatedBy.startsWith("parquet-mr")
val convertTz =
if (timestampConversion && !isCreatedByParquetMr) {
@@ -528,7 +526,7 @@ object ParquetFileFormat extends Logging {
// when it can't read the footer.
Some(
new Footer(
- currentFile.getPath(),
+ currentFile.getPath,
ParquetFooterReader.readFooter(conf, currentFile, SKIP_ROW_GROUPS)))
} catch {
case e: RuntimeException =>
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
index eb0f6a5d97..c21c67f654 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
@@ -16,8 +16,7 @@
*/
package org.apache.spark.sql.hive.execution
-import org.apache.gluten.execution.datasource.GlutenOrcWriterInjects
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SPECULATION_ENABLED
@@ -114,13 +113,7 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc)
orcOptions.compressionCodec
}
- val nativeConf = if (isParquetFormat) {
- logInfo("Use Gluten parquet write for hive")
- GlutenParquetWriterInjects.getInstance().nativeConf(options, compressionCodec)
- } else {
- logInfo("Use Gluten orc write for hive")
- GlutenOrcWriterInjects.getInstance().nativeConf(options, compressionCodec)
- }
+ val nativeConf = GlutenFormatFactory(nativeFormat).nativeConf(options, compressionCodec)
new OutputWriterFactory {
private val jobConf = new SerializableJobConf(new JobConf(conf))
@@ -135,15 +128,8 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc)
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- if (isParquetFormat) {
- GlutenParquetWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
- } else {
- GlutenOrcWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
- }
+ GlutenFormatFactory(nativeFormat)
+ .createOutputWriter(path, dataSchema, context, nativeConf)
}
}
} else {
diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
index 5cf7c5505a..2135780d05 100644
--- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.sql.shims.spark33
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.{ExpressionNames, Sig}
import org.apache.gluten.expression.ExpressionNames.{CEIL, FLOOR, KNOWN_NULLABLE, TIMESTAMP_ADD}
import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}
@@ -117,7 +117,7 @@ class Spark33Shims extends SparkShims {
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): TextScan = {
- new TextScan(
+ TextScan(
sparkSession,
fileIndex,
dataSchema,
@@ -217,7 +217,7 @@ class Spark33Shims extends SparkShims {
file: PartitionedFile,
metadataColumnNames: Seq[String]): JMap[String, String] = {
val metadataColumn = new JHashMap[String, String]()
- val path = new Path(file.filePath.toString)
+ val path = new Path(file.filePath)
for (columnName <- metadataColumnNames) {
columnName match {
case FileFormat.FILE_PATH => metadataColumn.put(FileFormat.FILE_PATH,
path.toString)
@@ -246,7 +246,7 @@ class Spark33Shims extends SparkShims {
}
override def getExtendedColumnarPostRules(): List[SparkSession =>
Rule[SparkPlan]] = {
- List(session =>
GlutenParquetWriterInjects.getInstance().getExtendedColumnarPostRule(session))
+ List(session => GlutenFormatFactory.getExtendedColumnarPostRule(session))
}
override def createTestTaskContext(properties: Properties): TaskContext = {
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index f4215ac136..fd82b34e14 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources
-import org.apache.gluten.execution.datasource.GlutenRowSplitter
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
@@ -410,7 +410,7 @@ class DynamicPartitionDataSingleWriter(
record match {
case fakeRow: FakeRow =>
if (fakeRow.batch.numRows() > 0) {
- val blockStripes = GlutenRowSplitter.getInstance
+ val blockStripes = GlutenFormatFactory.rowSplitter
.splitBlockByPartitionAndBucket(fakeRow, partitionColIndice, isBucketed)
val iter = blockStripes.iterator()
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index f5e932337c..0fdc795135 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -16,8 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources
-import org.apache.gluten.execution.datasource.GlutenOrcWriterInjects
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -51,7 +50,7 @@ import java.util.{Date, UUID}
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
- var executeWriterWrappedSparkPlan: SparkPlan => RDD[InternalRow] = null
+ var executeWriterWrappedSparkPlan: SparkPlan => RDD[InternalRow] = _
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
@@ -277,11 +276,7 @@ object FileFormatWriter extends Logging {
}
val nativeFormat = sparkSession.sparkContext.getLocalProperty("nativeFormat")
- if ("parquet" == nativeFormat) {
- (GlutenParquetWriterInjects.getInstance().executeWriterWrappedSparkPlan(wrapped), None)
- } else {
- (GlutenOrcWriterInjects.getInstance().executeWriterWrappedSparkPlan(wrapped), None)
- }
+ (GlutenFormatFactory(nativeFormat).executeWriterWrappedSparkPlan(wrapped), None)
}
try {
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 979fe9faf4..0676f12f1f 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.orc
-import org.apache.gluten.execution.datasource.GlutenOrcWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
@@ -85,9 +85,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
if ("true" == sparkSession.sparkContext.getLocalProperty("isNativeApplicable")) {
// pass compression to job conf so that the file extension can be aware of it.
val nativeConf =
- GlutenOrcWriterInjects
- .getInstance()
- .nativeConf(options, orcOptions.compressionCodec)
+ GlutenFormatFactory(shortName()).nativeConf(options, orcOptions.compressionCodec)
new OutputWriterFactory {
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -103,10 +101,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenOrcWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
-
+ GlutenFormatFactory(shortName())
+ .createOutputWriter(path, dataSchema, context, nativeConf)
}
}
} else {
@@ -155,7 +151,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
- hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 7064f1a6f2..462627bfb7 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.parquet
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -81,8 +81,7 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
val nativeConf =
- GlutenParquetWriterInjects
- .getInstance()
+ GlutenFormatFactory(shortName())
.nativeConf(options, parquetOptions.compressionCodecClassName)
new OutputWriterFactory {
@@ -94,9 +93,8 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- GlutenParquetWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
+ GlutenFormatFactory(shortName())
+ .createOutputWriter(path, dataSchema, context, nativeConf)
}
}
@@ -233,7 +231,7 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
- hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, requiredSchema.json)
hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, requiredSchema.json)
@@ -323,7 +321,7 @@ class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging
// have different writers.
// Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads.
def isCreatedByParquetMr: Boolean =
- footerFileMetaData.getCreatedBy().startsWith("parquet-mr")
+ footerFileMetaData.getCreatedBy.startsWith("parquet-mr")
val convertTz =
if (timestampConversion && !isCreatedByParquetMr) {
@@ -514,7 +512,7 @@ object ParquetFileFormat extends Logging {
// when it can't read the footer.
Some(
new Footer(
- currentFile.getPath(),
+ currentFile.getPath,
ParquetFooterReader.readFooter(conf, currentFile, SKIP_ROW_GROUPS)))
} catch {
case e: RuntimeException =>
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
index b9c1622cbe..6ed1b4d215 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
@@ -16,8 +16,7 @@
*/
package org.apache.spark.sql.hive.execution
-import org.apache.gluten.execution.datasource.GlutenOrcWriterInjects
-import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
+import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SPECULATION_ENABLED
@@ -111,13 +110,7 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc)
orcOptions.compressionCodec
}
- val nativeConf = if (isParquetFormat) {
- logInfo("Use Gluten parquet write for hive")
- GlutenParquetWriterInjects.getInstance().nativeConf(options, compressionCodec)
- } else {
- logInfo("Use Gluten orc write for hive")
- GlutenOrcWriterInjects.getInstance().nativeConf(options, compressionCodec)
- }
+ val nativeConf = GlutenFormatFactory(nativeFormat).nativeConf(options, compressionCodec)
new OutputWriterFactory {
private val jobConf = new SerializableJobConf(new JobConf(conf))
@@ -132,15 +125,8 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc)
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- if (isParquetFormat) {
- GlutenParquetWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
- } else {
- GlutenOrcWriterInjects
- .getInstance()
- .createOutputWriter(path, dataSchema, context, nativeConf);
- }
+ GlutenFormatFactory(nativeFormat)
+ .createOutputWriter(path, dataSchema, context, nativeConf)
}
}
} else {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]