This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 1fbdbc417 [GLUTEN-6067][CH] [Part 2] Support CH backend with Spark3.5
- Prepare for supporting sink transform (#6197)
1fbdbc417 is described below
commit 1fbdbc41779321db3380bce0807b73389af64e1a
Author: Chang chen <[email protected]>
AuthorDate: Tue Jun 25 07:13:39 2024 +0800
[GLUTEN-6067][CH] [Part 2] Support CH backend with Spark3.5 - Prepare for
supporting sink transform (#6197)
[CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting
sink transform
* [Refactor] Remove duplicate code
* Add NativeWriteChecker
* [Prepare to commit] Get extended columnar post-rules from the Spark shim (getExtendedColumnarPostRules)
---
.../backendsapi/clickhouse/CHIteratorApi.scala | 143 +++--
.../clickhouse/CHSparkPlanExecApi.scala | 9 -
.../execution/CHHashJoinExecTransformer.scala | 3 +-
.../GlutenClickHouseNativeWriteTableSuite.scala | 612 +++++++++------------
.../metrics/GlutenClickHouseTPCHMetricsSuite.scala | 2 +-
.../apache/spark/gluten/NativeWriteChecker.scala | 52 ++
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 9 -
cpp-ch/local-engine/Common/CHUtil.cpp | 17 +-
cpp-ch/local-engine/Common/CHUtil.h | 12 +-
cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp | 2 +-
.../local-engine/Parser/SerializedPlanParser.cpp | 310 ++++-------
cpp-ch/local-engine/Parser/SerializedPlanParser.h | 39 +-
cpp-ch/local-engine/local_engine_jni.cpp | 39 +-
.../local-engine/tests/benchmark_local_engine.cpp | 80 +--
cpp-ch/local-engine/tests/gluten_test_util.h | 18 +
cpp-ch/local-engine/tests/gtest_local_engine.cpp | 22 +-
cpp-ch/local-engine/tests/gtest_parser.cpp | 407 ++++----------
.../clickhouse_pr_65234.json} | 49 +-
.../tests/json/gtest_local_engine_config.json | 269 +++++++++
.../tests/json/read_student_option_schema.csv.json | 77 +++
.../gluten/backendsapi/SparkPlanExecApi.scala | 4 +-
.../gluten/utils/SubstraitPlanPrinterUtil.scala | 35 +-
22 files changed, 1107 insertions(+), 1103 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
index 941237629..376e46ebe 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.backendsapi.clickhouse
-import org.apache.gluten.{GlutenConfig, GlutenNumaBindingInfo}
+import org.apache.gluten.GlutenNumaBindingInfo
import org.apache.gluten.backendsapi.IteratorApi
import org.apache.gluten.execution._
import org.apache.gluten.expression.ConverterUtils
@@ -61,6 +61,52 @@ class CHIteratorApi extends IteratorApi with Logging with
LogLevelUtil {
StructType(dataSchema)
}
+ private def createNativeIterator(
+ splitInfoByteArray: Array[Array[Byte]],
+ wsPlan: Array[Byte],
+ materializeInput: Boolean,
+ inputIterators: Seq[Iterator[ColumnarBatch]]): BatchIterator = {
+
+ /** Generate closeable ColumnBatch iterator. */
+ val listIterator =
+ inputIterators
+ .map {
+ case i: CloseableCHColumnBatchIterator => i
+ case it => new CloseableCHColumnBatchIterator(it)
+ }
+ .map(it => new
ColumnarNativeIterator(it.asJava).asInstanceOf[GeneralInIterator])
+ .asJava
+ new CHNativeExpressionEvaluator().createKernelWithBatchIterator(
+ wsPlan,
+ splitInfoByteArray,
+ listIterator,
+ materializeInput
+ )
+ }
+
+ private def createCloseIterator(
+ context: TaskContext,
+ pipelineTime: SQLMetric,
+ updateNativeMetrics: IMetrics => Unit,
+ updateInputMetrics: Option[InputMetricsWrapper => Unit] = None,
+ nativeIter: BatchIterator): CloseableCHColumnBatchIterator = {
+
+ val iter = new CollectMetricIterator(
+ nativeIter,
+ updateNativeMetrics,
+ updateInputMetrics,
+ updateInputMetrics.map(_ => context.taskMetrics().inputMetrics).orNull)
+
+ context.addTaskFailureListener(
+ (ctx, _) => {
+ if (ctx.isInterrupted()) {
+ iter.cancel()
+ }
+ })
+ context.addTaskCompletionListener[Unit](_ => iter.close())
+ new CloseableCHColumnBatchIterator(iter, Some(pipelineTime))
+ }
+
// only set file schema for text format table
private def setFileSchemaForLocalFiles(
localFilesNode: LocalFilesNode,
@@ -198,45 +244,24 @@ class CHIteratorApi extends IteratorApi with Logging with
LogLevelUtil {
inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()
): Iterator[ColumnarBatch] = {
- assert(
+ require(
inputPartition.isInstanceOf[GlutenPartition],
"CH backend only accepts GlutenPartition in
GlutenWholeStageColumnarRDD.")
-
- val transKernel = new CHNativeExpressionEvaluator()
- val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
- iter => new
ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava)
- }.asJava)
-
val splitInfoByteArray = inputPartition
.asInstanceOf[GlutenPartition]
.splitInfosByteArray
- val nativeIter =
- transKernel.createKernelWithBatchIterator(
- inputPartition.plan,
- splitInfoByteArray,
- inBatchIters,
- false)
+ val wsPlan = inputPartition.plan
+ val materializeInput = false
- val iter = new CollectMetricIterator(
- nativeIter,
- updateNativeMetrics,
- updateInputMetrics,
- context.taskMetrics().inputMetrics)
-
- context.addTaskFailureListener(
- (ctx, _) => {
- if (ctx.isInterrupted()) {
- iter.cancel()
- }
- })
- context.addTaskCompletionListener[Unit](_ => iter.close())
-
- // TODO: SPARK-25083 remove the type erasure hack in data source scan
new InterruptibleIterator(
context,
- new CloseableCHColumnBatchIterator(
- iter.asInstanceOf[Iterator[ColumnarBatch]],
- Some(pipelineTime)))
+ createCloseIterator(
+ context,
+ pipelineTime,
+ updateNativeMetrics,
+ Some(updateInputMetrics),
+ createNativeIterator(splitInfoByteArray, wsPlan, materializeInput,
inputIterators))
+ )
}
// Generate Iterator[ColumnarBatch] for final stage.
@@ -252,52 +277,26 @@ class CHIteratorApi extends IteratorApi with Logging with
LogLevelUtil {
partitionIndex: Int,
materializeInput: Boolean): Iterator[ColumnarBatch] = {
// scalastyle:on argcount
- GlutenConfig.getConf
-
- val transKernel = new CHNativeExpressionEvaluator()
- val columnarNativeIterator =
- new JArrayList[GeneralInIterator](inputIterators.map {
- iter =>
- new
ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava)
- }.asJava)
- // we need to complete dependency RDD's firstly
- val nativeIterator = transKernel.createKernelWithBatchIterator(
- rootNode.toProtobuf.toByteArray,
- // Final iterator does not contain scan split, so pass empty split info
to native here.
- new Array[Array[Byte]](0),
- columnarNativeIterator,
- materializeInput
- )
-
- val iter = new CollectMetricIterator(nativeIterator, updateNativeMetrics,
null, null)
- context.addTaskFailureListener(
- (ctx, _) => {
- if (ctx.isInterrupted()) {
- iter.cancel()
- }
- })
- context.addTaskCompletionListener[Unit](_ => iter.close())
- new CloseableCHColumnBatchIterator(iter, Some(pipelineTime))
- }
-}
+ // Final iterator does not contain scan split, so pass empty split info to
native here.
+ val splitInfoByteArray = new Array[Array[Byte]](0)
+ val wsPlan = rootNode.toProtobuf.toByteArray
-object CHIteratorApi {
-
- /** Generate closeable ColumnBatch iterator. */
- def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]):
Iterator[ColumnarBatch] = {
- iter match {
- case _: CloseableCHColumnBatchIterator => iter
- case _ => new CloseableCHColumnBatchIterator(iter)
- }
+ // we need to complete dependency RDD's firstly
+ createCloseIterator(
+ context,
+ pipelineTime,
+ updateNativeMetrics,
+ None,
+ createNativeIterator(splitInfoByteArray, wsPlan, materializeInput,
inputIterators))
}
}
class CollectMetricIterator(
val nativeIterator: BatchIterator,
val updateNativeMetrics: IMetrics => Unit,
- val updateInputMetrics: InputMetricsWrapper => Unit,
- val inputMetrics: InputMetrics
+ val updateInputMetrics: Option[InputMetricsWrapper => Unit] = None,
+ val inputMetrics: InputMetrics = null
) extends Iterator[ColumnarBatch] {
private var outputRowCount = 0L
private var outputVectorCount = 0L
@@ -329,9 +328,7 @@ class CollectMetricIterator(
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
- if (updateInputMetrics != null) {
- updateInputMetrics(inputMetrics)
- }
+ updateInputMetrics.foreach(_(inputMetrics))
metricsUpdated = true
}
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 1c83e326e..ac3ea61ff 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -50,7 +50,6 @@ import org.apache.spark.sql.delta.files.TahoeFileIndex
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.datasources.{FileFormat,
HadoopFsRelation}
-import
org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import
org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
@@ -583,14 +582,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
override def genExtendedColumnarTransformRules(): List[SparkSession =>
Rule[SparkPlan]] =
List()
- /**
- * Generate extended columnar post-rules.
- *
- * @return
- */
- override def genExtendedColumnarPostRules(): List[SparkSession =>
Rule[SparkPlan]] =
- List(spark => NativeWritePostRule(spark))
-
override def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]] = {
List()
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index a7e7769e7..da9d9c758 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -16,7 +16,6 @@
*/
package org.apache.gluten.execution
-import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil,
ShuffleHashJoinStrategy}
@@ -75,7 +74,7 @@ case class CHBroadcastBuildSideRDD(
override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = {
CHBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted,
broadcastContext)
- CHIteratorApi.genCloseableColumnBatchIterator(Iterator.empty)
+ Iterator.empty
}
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 9269303d9..ccf7bb5d5 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -21,6 +21,7 @@ import
org.apache.gluten.execution.AllDataTypesWithComplexType.genTestData
import org.apache.gluten.utils.UTSystemParameters
import org.apache.spark.SparkConf
+import org.apache.spark.gluten.NativeWriteChecker
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -28,11 +29,14 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.scalatest.BeforeAndAfterAll
+import scala.reflect.runtime.universe.TypeTag
+
class GlutenClickHouseNativeWriteTableSuite
extends GlutenClickHouseWholeStageTransformerSuite
with AdaptiveSparkPlanHelper
with SharedSparkSession
- with BeforeAndAfterAll {
+ with BeforeAndAfterAll
+ with NativeWriteChecker {
private var _hiveSpark: SparkSession = _
@@ -114,16 +118,19 @@ class GlutenClickHouseNativeWriteTableSuite
def getColumnName(s: String): String = {
s.replaceAll("\\(", "_").replaceAll("\\)", "_")
}
+
import collection.immutable.ListMap
import java.io.File
def writeIntoNewTableWithSql(table_name: String, table_create_sql: String)(
fields: Seq[String]): Unit = {
- spark.sql(table_create_sql)
- spark.sql(
- s"insert overwrite $table_name select ${fields.mkString(",")}" +
- s" from origin_table")
+ withDestinationTable(table_name, table_create_sql) {
+ checkNativeWrite(
+ s"insert overwrite $table_name select ${fields.mkString(",")}" +
+ s" from origin_table",
+ checkNative = true)
+ }
}
def writeAndCheckRead(
@@ -170,82 +177,86 @@ class GlutenClickHouseNativeWriteTableSuite
})
}
- test("test insert into dir") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val originDF = spark.createDataFrame(genTestData())
- originDF.createOrReplaceTempView("origin_table")
+ private val fields_ = ListMap(
+ ("string_field", "string"),
+ ("int_field", "int"),
+ ("long_field", "long"),
+ ("float_field", "float"),
+ ("double_field", "double"),
+ ("short_field", "short"),
+ ("byte_field", "byte"),
+ ("boolean_field", "boolean"),
+ ("decimal_field", "decimal(23,12)"),
+ ("date_field", "date")
+ )
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
+ def withDestinationTable(table: String, createTableSql: String)(f: => Unit):
Unit = {
+ spark.sql(s"drop table IF EXISTS $table")
+ spark.sql(s"$createTableSql")
+ f
+ }
- for (format <- formats) {
- spark.sql(
- s"insert overwrite local directory
'$basePath/test_insert_into_${format}_dir1' "
- + s"stored as $format select "
- + fields.keys.mkString(",") +
- " from origin_table cluster by (byte_field)")
- spark.sql(
- s"insert overwrite local directory
'$basePath/test_insert_into_${format}_dir2' " +
- s"stored as $format " +
- "select string_field, sum(int_field) as x from origin_table group
by string_field")
- }
+ def nativeWrite(f: String => Unit): Unit = {
+ withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
+ formats.foreach(f(_))
}
}
- test("test insert into partition") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- ("spark.sql.orc.compression.codec", "lz4"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val originDF = spark.createDataFrame(genTestData())
- originDF.createOrReplaceTempView("origin_table")
-
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
-
- for (format <- formats) {
- val table_name = table_name_template.format(format)
- spark.sql(s"drop table IF EXISTS $table_name")
+ def nativeWrite2(
+ f: String => (String, String, String),
+ extraCheck: (String, String, String) => Unit = null): Unit = nativeWrite
{
+ format =>
+ val (table_name, table_create_sql, insert_sql) = f(format)
+ withDestinationTable(table_name, table_create_sql) {
+ checkNativeWrite(insert_sql, checkNative = true)
+ Option(extraCheck).foreach(_(table_name, table_create_sql, insert_sql))
+ }
+ }
- val table_create_sql =
- s"create table if not exists $table_name (" +
- fields
- .map(f => s"${f._1} ${f._2}")
- .mkString(",") +
- " ) partitioned by (another_date_field date) " +
- s"stored as $format"
+ def nativeWriteWithOriginalView[A <: Product: TypeTag](
+ data: Seq[A],
+ viewName: String,
+ pairs: (String, String)*)(f: String => Unit): Unit = {
+ val configs = pairs :+ ("spark.gluten.sql.native.writer.enabled", "true")
+ withSQLConf(configs: _*) {
+ withTempView(viewName) {
+ spark.createDataFrame(data).createOrReplaceTempView(viewName)
+ formats.foreach(f(_))
+ }
+ }
+ }
- spark.sql(table_create_sql)
+ test("test insert into dir") {
+ nativeWriteWithOriginalView(genTestData(), "origin_table") {
+ format =>
+ Seq(
+ s"""insert overwrite local directory
'$basePath/test_insert_into_${format}_dir1'
+ |stored as $format select ${fields_.keys.mkString(",")}
+ |from origin_table""".stripMargin,
+ s"""insert overwrite local directory
'$basePath/test_insert_into_${format}_dir2'
+ |stored as $format select string_field, sum(int_field) as x
+ |from origin_table group by string_field""".stripMargin
+ ).foreach(checkNativeWrite(_, checkNative = true))
+ }
+ }
- spark.sql(
- s"insert into $table_name partition(another_date_field =
'2020-01-01') select "
- + fields.keys.mkString(",") +
- " from origin_table")
+ test("test insert into partition") {
+ def destination(format: String): (String, String, String) = {
+ val table_name = table_name_template.format(format)
+ val table_create_sql =
+ s"""create table if not exists $table_name
+ |(${fields_.map(f => s"${f._1} ${f._2}").mkString(",")})
+ |partitioned by (another_date_field date) stored as
$format""".stripMargin
+ val insert_sql =
+ s"""insert into $table_name partition(another_date_field =
'2020-01-01')
+ | select ${fields_.keys.mkString(",")} from
origin_table""".stripMargin
+ (table_name, table_create_sql, insert_sql)
+ }
+ def nativeFormatWrite(format: String): Unit = {
+ val (table_name, table_create_sql, insert_sql) = destination(format)
+ withDestinationTable(table_name, table_create_sql) {
+ checkNativeWrite(insert_sql, checkNative = true)
var files = recursiveListFiles(new File(getWarehouseDir + "/" +
table_name))
.filter(_.getName.endsWith(s".$format"))
if (format == "orc") {
@@ -255,154 +266,103 @@ class GlutenClickHouseNativeWriteTableSuite
assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01"))
}
}
+
+ nativeWriteWithOriginalView(
+ genTestData(),
+ "origin_table",
+ ("spark.sql.orc.compression.codec", "lz4"))(nativeFormatWrite)
}
test("test CTAS") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val originDF = spark.createDataFrame(genTestData())
- originDF.createOrReplaceTempView("origin_table")
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
-
- for (format <- formats) {
+ nativeWriteWithOriginalView(genTestData(), "origin_table") {
+ format =>
val table_name = table_name_template.format(format)
- spark.sql(s"drop table IF EXISTS $table_name")
val table_create_sql =
s"create table $table_name using $format as select " +
- fields
+ fields_
.map(f => s"${f._1}")
.mkString(",") +
" from origin_table"
- spark.sql(table_create_sql)
- spark.sql(s"drop table IF EXISTS $table_name")
+ val insert_sql =
+ s"create table $table_name as select " +
+ fields_
+ .map(f => s"${f._1}")
+ .mkString(",") +
+ " from origin_table"
+ withDestinationTable(table_name, table_create_sql) {
+ spark.sql(s"drop table IF EXISTS $table_name")
- try {
- val table_create_sql =
- s"create table $table_name as select " +
- fields
- .map(f => s"${f._1}")
- .mkString(",") +
- " from origin_table"
- spark.sql(table_create_sql)
- } catch {
- case _: UnsupportedOperationException => // expected
- case _: Exception => fail("should not throw exception")
+ try {
+ // FIXME: using checkNativeWrite
+ spark.sql(insert_sql)
+ } catch {
+ case _: UnsupportedOperationException => // expected
+ case e: Exception => fail("should not throw exception", e)
+ }
}
- }
}
}
test("test insert into partition, bigo's case which incur
InsertIntoHiveTable") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- ("spark.sql.hive.convertMetastoreParquet", "false"),
- ("spark.sql.hive.convertMetastoreOrc", "false"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")
- ) {
-
- val originDF = spark.createDataFrame(genTestData())
- originDF.createOrReplaceTempView("origin_table")
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
-
- for (format <- formats) {
- val table_name = table_name_template.format(format)
- spark.sql(s"drop table IF EXISTS $table_name")
- val table_create_sql = s"create table if not exists $table_name (" +
fields
- .map(f => s"${f._1} ${f._2}")
- .mkString(",") + " ) partitioned by (another_date_field string)" +
- s"stored as $format"
+ def destination(format: String): (String, String, String) = {
+ val table_name = table_name_template.format(format)
+ val table_create_sql = s"create table if not exists $table_name (" +
fields_
+ .map(f => s"${f._1} ${f._2}")
+ .mkString(",") + " ) partitioned by (another_date_field string)" +
+ s"stored as $format"
+ val insert_sql =
+ s"insert overwrite table $table_name " +
+ "partition(another_date_field = '2020-01-01') select " +
+ fields_.keys.mkString(",") + " from (select " +
fields_.keys.mkString(
+ ",") + ", row_number() over (order by int_field desc) as rn " +
+ "from origin_table where float_field > 3 ) tt where rn <= 100"
+ (table_name, table_create_sql, insert_sql)
+ }
- spark.sql(table_create_sql)
- spark.sql(
- s"insert overwrite table $table_name " +
- "partition(another_date_field = '2020-01-01') select "
- + fields.keys.mkString(",") + " from (select " +
fields.keys.mkString(
- ",") + ", row_number() over (order by int_field desc) as rn " +
- "from origin_table where float_field > 3 ) tt where rn <= 100")
+ def nativeFormatWrite(format: String): Unit = {
+ val (table_name, table_create_sql, insert_sql) = destination(format)
+ withDestinationTable(table_name, table_create_sql) {
+ checkNativeWrite(insert_sql, checkNative = true)
val files = recursiveListFiles(new File(getWarehouseDir + "/" +
table_name))
.filter(_.getName.startsWith("part"))
assert(files.length == 1)
assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01"))
}
}
+
+ nativeWriteWithOriginalView(
+ genTestData(),
+ "origin_table",
+ ("spark.sql.hive.convertMetastoreParquet", "false"),
+ ("spark.sql.hive.convertMetastoreOrc", "false"))(nativeFormatWrite)
}
test("test 1-col partitioned table") {
+ nativeWrite {
+ format =>
+ {
+ val table_name = table_name_template.format(format)
+ val table_create_sql =
+ s"create table if not exists $table_name (" +
+ fields_
+ .filterNot(e => e._1.equals("date_field"))
+ .map(f => s"${f._1} ${f._2}")
+ .mkString(",") +
+ " ) partitioned by (date_field date) " +
+ s"stored as $format"
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
-
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
-
- for (format <- formats) {
- val table_name = table_name_template.format(format)
- val table_create_sql =
- s"create table if not exists $table_name (" +
- fields
- .filterNot(e => e._1.equals("date_field"))
- .map(f => s"${f._1} ${f._2}")
- .mkString(",") +
- " ) partitioned by (date_field date) " +
- s"stored as $format"
-
- writeAndCheckRead(
- table_name,
- writeIntoNewTableWithSql(table_name, table_create_sql),
- fields.keys.toSeq)
- }
+ writeAndCheckRead(
+ table_name,
+ writeIntoNewTableWithSql(table_name, table_create_sql),
+ fields_.keys.toSeq)
+ }
}
}
// even if disable native writer, this UT fail, spark bug???
ignore("test 1-col partitioned table, partitioned by already ordered
column") {
withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) {
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
val originDF = spark.createDataFrame(genTestData())
originDF.createOrReplaceTempView("origin_table")
@@ -410,7 +370,7 @@ class GlutenClickHouseNativeWriteTableSuite
val table_name = table_name_template.format(format)
val table_create_sql =
s"create table if not exists $table_name (" +
- fields
+ fields_
.filterNot(e => e._1.equals("date_field"))
.map(f => s"${f._1} ${f._2}")
.mkString(",") +
@@ -420,31 +380,27 @@ class GlutenClickHouseNativeWriteTableSuite
spark.sql(s"drop table IF EXISTS $table_name")
spark.sql(table_create_sql)
spark.sql(
- s"insert overwrite $table_name select ${fields.mkString(",")}" +
+ s"insert overwrite $table_name select ${fields_.mkString(",")}" +
s" from origin_table order by date_field")
}
}
}
test("test 2-col partitioned table") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date"),
- ("byte_field", "byte")
- )
-
- for (format <- formats) {
+ val fields: ListMap[String, String] = ListMap(
+ ("string_field", "string"),
+ ("int_field", "int"),
+ ("long_field", "long"),
+ ("float_field", "float"),
+ ("double_field", "double"),
+ ("short_field", "short"),
+ ("boolean_field", "boolean"),
+ ("decimal_field", "decimal(23,12)"),
+ ("date_field", "date"),
+ ("byte_field", "byte")
+ )
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
val table_create_sql =
s"create table if not exists $table_name (" +
@@ -458,7 +414,6 @@ class GlutenClickHouseNativeWriteTableSuite
table_name,
writeIntoNewTableWithSql(table_name, table_create_sql),
fields.keys.toSeq)
- }
}
}
@@ -506,25 +461,21 @@ class GlutenClickHouseNativeWriteTableSuite
// This test case will be failed with incorrect result randomly, ignore
first.
ignore("test hive parquet/orc table, all columns being partitioned. ") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val fields: ListMap[String, String] = ListMap(
- ("date_field", "date"),
- ("timestamp_field", "timestamp"),
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)")
- )
-
- for (format <- formats) {
+ val fields: ListMap[String, String] = ListMap(
+ ("date_field", "date"),
+ ("timestamp_field", "timestamp"),
+ ("string_field", "string"),
+ ("int_field", "int"),
+ ("long_field", "long"),
+ ("float_field", "float"),
+ ("double_field", "double"),
+ ("short_field", "short"),
+ ("byte_field", "byte"),
+ ("boolean_field", "boolean"),
+ ("decimal_field", "decimal(23,12)")
+ )
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
val table_create_sql =
s"create table if not exists $table_name (" +
@@ -540,20 +491,15 @@ class GlutenClickHouseNativeWriteTableSuite
table_name,
writeIntoNewTableWithSql(table_name, table_create_sql),
fields.keys.toSeq)
- }
}
}
- test(("test hive parquet/orc table with aggregated results")) {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val fields: ListMap[String, String] = ListMap(
- ("sum(int_field)", "bigint")
- )
-
- for (format <- formats) {
+ test("test hive parquet/orc table with aggregated results") {
+ val fields: ListMap[String, String] = ListMap(
+ ("sum(int_field)", "bigint")
+ )
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
val table_create_sql =
s"create table if not exists $table_name (" +
@@ -566,29 +512,12 @@ class GlutenClickHouseNativeWriteTableSuite
table_name,
writeIntoNewTableWithSql(table_name, table_create_sql),
fields.keys.toSeq)
- }
}
}
test("test 1-col partitioned + 1-col bucketed table") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
-
- val fields: ListMap[String, String] = ListMap(
- ("string_field", "string"),
- ("int_field", "int"),
- ("long_field", "long"),
- ("float_field", "float"),
- ("double_field", "double"),
- ("short_field", "short"),
- ("byte_field", "byte"),
- ("boolean_field", "boolean"),
- ("decimal_field", "decimal(23,12)"),
- ("date_field", "date")
- )
-
- for (format <- formats) {
+ nativeWrite {
+ format =>
// spark write does not support bucketed table
// https://issues.apache.org/jira/browse/SPARK-19256
val table_name = table_name_template.format(format)
@@ -604,7 +533,7 @@ class GlutenClickHouseNativeWriteTableSuite
.bucketBy(2, "byte_field")
.saveAsTable(table_name)
},
- fields.keys.toSeq
+ fields_.keys.toSeq
)
assert(
@@ -614,10 +543,8 @@ class GlutenClickHouseNativeWriteTableSuite
.filter(!_.getName.equals("date_field=__HIVE_DEFAULT_PARTITION__"))
.head
.listFiles()
- .filter(!_.isHidden)
- .length == 2
+ .count(!_.isHidden) == 2
) // 2 bucket files
- }
}
}
@@ -745,8 +672,8 @@ class GlutenClickHouseNativeWriteTableSuite
}
test("test consecutive blocks having same partition value") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
@@ -760,15 +687,14 @@ class GlutenClickHouseNativeWriteTableSuite
.partitionBy("p")
.saveAsTable(table_name)
- val ret = spark.sql("select sum(id) from " +
table_name).collect().apply(0).apply(0)
+ val ret = spark.sql(s"select sum(id) from
$table_name").collect().apply(0).apply(0)
assert(ret == 449985000)
- }
}
}
test("test decimal with rand()") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
spark
@@ -778,32 +704,30 @@ class GlutenClickHouseNativeWriteTableSuite
.format(format)
.partitionBy("p")
.saveAsTable(table_name)
- val ret = spark.sql("select max(p) from " +
table_name).collect().apply(0).apply(0)
- }
+ val ret = spark.sql(s"select max(p) from
$table_name").collect().apply(0).apply(0)
}
}
test("test partitioned by constant") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
- spark.sql(s"drop table IF EXISTS tmp_123_$format")
- spark.sql(
- s"create table tmp_123_$format(" +
- s"x1 string, x2 bigint,x3 string, x4 bigint, x5 string )" +
- s"partitioned by (day date) stored as $format")
-
- spark.sql(
- s"insert into tmp_123_$format partition(day) " +
- "select cast(id as string), id, cast(id as string), id, cast(id as
string), " +
- "'2023-05-09' from range(10000000)")
- }
+ nativeWrite2 {
+ format =>
+ val table_name = s"tmp_123_$format"
+ val create_sql =
+ s"""create table tmp_123_$format(
+ |x1 string, x2 bigint,x3 string, x4 bigint, x5 string )
+ |partitioned by (day date) stored as $format""".stripMargin
+ val insert_sql =
+ s"""insert into tmp_123_$format partition(day)
+ |select cast(id as string), id, cast(id as string),
+ | id, cast(id as string), '2023-05-09'
+ |from range(10000000)""".stripMargin
+ (table_name, create_sql, insert_sql)
}
}
test("test bucketed by constant") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
-
- for (format <- formats) {
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
@@ -815,15 +739,13 @@ class GlutenClickHouseNativeWriteTableSuite
.bucketBy(2, "p")
.saveAsTable(table_name)
- val ret = spark.sql("select count(*) from " +
table_name).collect().apply(0).apply(0)
- }
+ assertResult(10000000)(spark.table(table_name).count())
}
}
test("test consecutive null values being partitioned") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
-
- for (format <- formats) {
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
@@ -835,14 +757,13 @@ class GlutenClickHouseNativeWriteTableSuite
.partitionBy("p")
.saveAsTable(table_name)
- val ret = spark.sql("select count(*) from " +
table_name).collect().apply(0).apply(0)
- }
+ assertResult(30000)(spark.table(table_name).count())
}
}
test("test consecutive null values being bucketed") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
+ nativeWrite {
+ format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
@@ -854,78 +775,79 @@ class GlutenClickHouseNativeWriteTableSuite
.bucketBy(2, "p")
.saveAsTable(table_name)
- val ret = spark.sql("select count(*) from " +
table_name).collect().apply(0).apply(0)
- }
+ assertResult(30000)(spark.table(table_name).count())
}
}
test("test native write with empty dataset") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
+ nativeWrite2(
+ format => {
val table_name = "t_" + format
- spark.sql(s"drop table IF EXISTS $table_name")
- spark.sql(s"create table $table_name (id int, str string) stored as
$format")
- spark.sql(
- s"insert into $table_name select id, cast(id as string) from
range(10)" +
- " where id > 100")
+ (
+ table_name,
+ s"create table $table_name (id int, str string) stored as $format",
+ s"insert into $table_name select id, cast(id as string) from
range(10) where id > 100"
+ )
+ },
+ (table_name, _, _) => {
+ assertResult(0)(spark.table(table_name).count())
}
- }
+ )
}
test("test native write with union") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
+ nativeWrite {
+ format =>
val table_name = "t_" + format
- spark.sql(s"drop table IF EXISTS $table_name")
- spark.sql(s"create table $table_name (id int, str string) stored as
$format")
- spark.sql(
- s"insert overwrite table $table_name " +
- "select id, cast(id as string) from range(10) union all " +
- "select 10, '10' from range(10)")
- spark.sql(
- s"insert overwrite table $table_name " +
- "select id, cast(id as string) from range(10) union all " +
- "select 10, cast(id as string) from range(10)")
-
- }
+ withDestinationTable(
+ table_name,
+ s"create table $table_name (id int, str string) stored as $format") {
+ checkNativeWrite(
+ s"insert overwrite table $table_name " +
+ "select id, cast(id as string) from range(10) union all " +
+ "select 10, '10' from range(10)",
+ checkNative = true)
+ checkNativeWrite(
+ s"insert overwrite table $table_name " +
+ "select id, cast(id as string) from range(10) union all " +
+ "select 10, cast(id as string) from range(10)",
+ checkNative = true
+ )
+ }
}
}
test("test native write and non-native read consistency") {
- withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
- for (format <- formats) {
- val table_name = "t_" + format
- spark.sql(s"drop table IF EXISTS $table_name")
- spark.sql(s"create table $table_name (id int, name string, info
char(4)) stored as $format")
- spark.sql(
- s"insert overwrite table $table_name " +
- "select id, cast(id as string), concat('aaa', cast(id as string))
from range(10)")
+ nativeWrite2(
+ {
+ format =>
+ val table_name = "t_" + format
+ (
+ table_name,
+ s"create table $table_name (id int, name string, info char(4))
stored as $format",
+ s"insert overwrite table $table_name " +
+ "select id, cast(id as string), concat('aaa', cast(id as
string)) from range(10)"
+ )
+ },
+ (table_name, _, _) =>
compareResultsAgainstVanillaSpark(
s"select * from $table_name",
compareResult = true,
_ => {})
- }
- }
+ )
}
test("GLUTEN-4316: fix crash on dynamic partition inserting") {
- withSQLConf(
- ("spark.gluten.sql.native.writer.enabled", "true"),
- (GlutenConfig.GLUTEN_ENABLED.key, "true")) {
- formats.foreach(
- format => {
- val tbl = "t_" + format
- spark.sql(s"drop table IF EXISTS $tbl")
- val sql1 =
- s"create table $tbl(a int, b map<string, string>, c
struct<d:string, e:string>) " +
- s"partitioned by (day string) stored as $format"
- val sql2 = s"insert overwrite $tbl partition (day) " +
- s"select id as a,
str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " +
- s"struct('1', null) as c, '2024-01-08' as day from range(10)"
- spark.sql(sql1)
- spark.sql(sql2)
- })
+ nativeWrite2 {
+ format =>
+ val tbl = "t_" + format
+ val sql1 =
+ s"create table $tbl(a int, b map<string, string>, c struct<d:string,
e:string>) " +
+ s"partitioned by (day string) stored as $format"
+ val sql2 = s"insert overwrite $tbl partition (day) " +
+ s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
" +
+ s"struct('1', null) as c, '2024-01-08' as day from range(10)"
+ (tbl, sql1, sql2)
}
}
-
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
index 09fa3ff10..1b3df8166 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
@@ -46,7 +46,7 @@ class GlutenClickHouseTPCHMetricsSuite extends
GlutenClickHouseTPCHAbstractSuite
.set("spark.io.compression.codec", "LZ4")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
- .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level",
"DEBUG")
+ //
.set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level",
"DEBUG")
.set(
"spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size",
s"$parquetMaxBlockSize")
diff --git
a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
new file mode 100644
index 000000000..79616d52d
--- /dev/null
+++
b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.gluten
+
+import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite
+
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.datasources.FakeRowAdaptor
+import org.apache.spark.sql.util.QueryExecutionListener
+
+/**
+ * Test mixin for asserting whether a SQL statement was executed through the native writer.
+ */
+trait NativeWriteChecker extends GlutenClickHouseWholeStageTransformerSuite {
+
+  /**
+   * Executes `sqlStr` and asserts whether the native writer was used.
+   *
+   * A [[QueryExecutionListener]] is registered only for the duration of the statement; its
+   * `onSuccess` callback inspects the executed plan. The callback is delivered through the
+   * listener bus, so the bus is drained before the assertion is made.
+   *
+   * @param sqlStr
+   *   the SQL statement to run (typically an INSERT)
+   * @param checkNative
+   *   expected result: `true` if the native writer must have been used
+   */
+  def checkNativeWrite(sqlStr: String, checkNative: Boolean): Unit = {
+    // Set on the listener-bus thread and read on the test thread, hence @volatile.
+    @volatile var nativeUsed = false
+
+    val queryListener = new QueryExecutionListener {
+      override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        if (!nativeUsed) {
+          nativeUsed = if (isSparkVersionGE("3.4")) {
+            // NOTE(review): detection is hard-wired to false on Spark >= 3.4, so any call with
+            // `checkNative = true` fails on those versions until a real check is added here.
+            false
+          } else {
+            qe.executedPlan.find(_.isInstanceOf[FakeRowAdaptor]).isDefined
+          }
+        }
+      }
+    }
+
+    spark.listenerManager.register(queryListener)
+    try {
+      spark.sql(sqlStr)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assertResult(checkNative)(nativeUsed)
+    } finally {
+      spark.listenerManager.unregister(queryListener)
+    }
+  }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 1f868c4c2..7b8d523a6 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -827,15 +827,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
buf.result
}
- /**
- * Generate extended columnar post-rules.
- *
- * @return
- */
- override def genExtendedColumnarPostRules(): List[SparkSession =>
Rule[SparkPlan]] = {
- SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List()
- }
-
override def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]] = {
List(ArrowConvertorRule)
}
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp
b/cpp-ch/local-engine/Common/CHUtil.cpp
index 94cd38003..ae3f6dbd5 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -77,6 +77,7 @@ namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int UNKNOWN_TYPE;
+extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
}
}
@@ -466,17 +467,17 @@ String
QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
using namespace DB;
-std::map<std::string, std::string>
BackendInitializerUtil::getBackendConfMap(std::string * plan)
+std::map<std::string, std::string>
BackendInitializerUtil::getBackendConfMap(const std::string & plan)
{
std::map<std::string, std::string> ch_backend_conf;
- if (plan == nullptr)
+ if (plan.empty())
return ch_backend_conf;
/// Parse backend configs from plan extensions
do
{
auto plan_ptr = std::make_unique<substrait::Plan>();
- auto success = plan_ptr->ParseFromString(*plan);
+ auto success = plan_ptr->ParseFromString(plan);
if (!success)
break;
@@ -841,14 +842,8 @@ void
BackendInitializerUtil::initCompiledExpressionCache(DB::Context::Configurat
#endif
}
-void BackendInitializerUtil::init_json(std::string * plan_json)
-{
- auto plan_ptr = std::make_unique<substrait::Plan>();
- google::protobuf::util::JsonStringToMessage(plan_json->c_str(),
plan_ptr.get());
- return init(new String(plan_ptr->SerializeAsString()));
-}
-void BackendInitializerUtil::init(std::string * plan)
+void BackendInitializerUtil::init(const std::string & plan)
{
std::map<std::string, std::string> backend_conf_map =
getBackendConfMap(plan);
DB::Context::ConfigurationPtr config = initConfig(backend_conf_map);
@@ -906,7 +901,7 @@ void BackendInitializerUtil::init(std::string * plan)
});
}
-void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr &
context, std::string * plan)
+void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr &
context, const std::string & plan)
{
std::map<std::string, std::string> backend_conf_map =
getBackendConfMap(plan);
diff --git a/cpp-ch/local-engine/Common/CHUtil.h
b/cpp-ch/local-engine/Common/CHUtil.h
index 94e0f0168..245d7b3d1 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -137,9 +137,8 @@ public:
/// Initialize two kinds of resources
/// 1. global level resources like global_context/shared_context, notice
that they can only be initialized once in process lifetime
/// 2. session level resources like settings/configs, they can be
initialized multiple times following the lifetime of executor/driver
- static void init(std::string * plan);
- static void init_json(std::string * plan_json);
- static void updateConfig(const DB::ContextMutablePtr &, std::string *);
+ static void init(const std::string & plan);
+ static void updateConfig(const DB::ContextMutablePtr &, const std::string
&);
// use excel text parser
@@ -196,7 +195,7 @@ private:
static void updateNewSettings(const DB::ContextMutablePtr &, const
DB::Settings &);
- static std::map<std::string, std::string> getBackendConfMap(std::string *
plan);
+ static std::map<std::string, std::string> getBackendConfMap(const
std::string & plan);
inline static std::once_flag init_flag;
inline static Poco::Logger * logger;
@@ -283,10 +282,7 @@ public:
return deq.empty();
}
- std::deque<T> unsafeGet()
- {
- return deq;
- }
+ std::deque<T> unsafeGet() { return deq; }
private:
std::deque<T> deq;
diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
index 2b4eb824a..5bb66e4b3 100644
--- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
+++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
@@ -453,7 +453,7 @@ std::unique_ptr<SparkRowInfo>
CHColumnToSparkRow::convertCHColumnToSparkRow(cons
if (!block.columns())
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty
columns");
std::unique_ptr<SparkRowInfo> spark_row_info =
std::make_unique<SparkRowInfo>(block, masks);
- spark_row_info->setBufferAddress(reinterpret_cast<char
*>(alloc(spark_row_info->getTotalBytes(), 64)));
+ spark_row_info->setBufferAddress(static_cast<char
*>(alloc(spark_row_info->getTotalBytes(), 64)));
//
spark_row_info->setBufferAddress(alignedAlloc(spark_row_info->getTotalBytes(),
64));
memset(spark_row_info->getBufferAddress(), 0,
spark_row_info->getTotalBytes());
for (auto col_idx = 0; col_idx < spark_row_info->getNumCols(); col_idx++)
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 70db692c8..3115950cd 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -87,14 +87,14 @@ namespace DB
{
namespace ErrorCodes
{
- extern const int LOGICAL_ERROR;
- extern const int UNKNOWN_TYPE;
- extern const int BAD_ARGUMENTS;
- extern const int NO_SUCH_DATA_PART;
- extern const int UNKNOWN_FUNCTION;
- extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
- extern const int ILLEGAL_TYPE_OF_ARGUMENT;
- extern const int INVALID_JOIN_ON_EXPRESSION;
+extern const int LOGICAL_ERROR;
+extern const int UNKNOWN_TYPE;
+extern const int BAD_ARGUMENTS;
+extern const int NO_SUCH_DATA_PART;
+extern const int UNKNOWN_FUNCTION;
+extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int INVALID_JOIN_ON_EXPRESSION;
}
}
@@ -144,16 +144,13 @@ void SerializedPlanParser::parseExtensions(
if (extension.has_extension_function())
{
function_mapping.emplace(
-
std::to_string(extension.extension_function().function_anchor()),
- extension.extension_function().name());
+
std::to_string(extension.extension_function().function_anchor()),
extension.extension_function().name());
}
}
}
std::shared_ptr<ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
- const std::vector<substrait::Expression> & expressions,
- const Block & header,
- const Block & read_schema)
+ const std::vector<substrait::Expression> & expressions, const Block &
header, const Block & read_schema)
{
auto actions_dag =
std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
NamesWithAliases required_columns;
@@ -259,8 +256,8 @@ std::string getDecimalFunction(const
substrait::Type_Decimal & decimal, bool nul
bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel)
{
- return rel.has_local_files() && rel.local_files().items().size() == 1 &&
rel.local_files().items().at(0).uri_file().starts_with(
- "iterator");
+ return rel.has_local_files() && rel.local_files().items().size() == 1
+ && rel.local_files().items().at(0).uri_file().starts_with("iterator");
}
bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel)
@@ -380,13 +377,13 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr
nested_type)
return nested_type;
}
-QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
+QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan)
{
- logDebugMessage(*plan, "substrait plan");
- parseExtensions(plan->extensions());
- if (plan->relations_size() == 1)
+ logDebugMessage(plan, "substrait plan");
+ parseExtensions(plan.extensions());
+ if (plan.relations_size() == 1)
{
- auto root_rel = plan->relations().at(0);
+ auto root_rel = plan.relations().at(0);
if (!root_rel.has_root())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!");
@@ -587,9 +584,7 @@ SerializedPlanParser::getFunctionName(const std::string &
function_signature, co
{
if (args.size() != 2)
throw Exception(
- ErrorCodes::BAD_ARGUMENTS,
- "Spark function extract requires two args, function:{}",
- function.ShortDebugString());
+ ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires
two args, function:{}", function.ShortDebugString());
// Get the first arg: field
const auto & extract_field = args.at(0);
@@ -705,9 +700,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
/// The argument number of arrayJoin(converted from Spark
explode/posexplode) should be 1
if (scalar_function.arguments_size() != 1)
throw Exception(
- ErrorCodes::BAD_ARGUMENTS,
- "Argument number of arrayJoin should be 1 instead of {}",
- scalar_function.arguments_size());
+ ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be
1 instead of {}", scalar_function.arguments_size());
auto function_name_copy = function_name;
parseFunctionArguments(actions_dag, parsed_args, function_name_copy,
scalar_function);
@@ -746,11 +739,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
}
ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
- const substrait::Expression & rel,
- std::vector<String> & result_names,
- ActionsDAGPtr actions_dag,
- bool keep_result,
- bool position)
+ const substrait::Expression & rel, std::vector<String> & result_names,
ActionsDAGPtr actions_dag, bool keep_result, bool position)
{
if (!rel.has_scalar_function())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression
should be a scalar function:\n {}", rel.DebugString());
@@ -774,7 +763,8 @@ ActionsDAG::NodeRawConstPtrs
SerializedPlanParser::parseArrayJoinWithDAG(
auto tuple_element_builder =
FunctionFactory::instance().get("sparkTupleElement", context);
auto tuple_index_type = std::make_shared<DataTypeUInt32>();
- auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t
i) -> const ActionsDAG::Node * {
+ auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t
i) -> const ActionsDAG::Node *
+ {
ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1,
i), tuple_index_type, getUniqueName(std::to_string(i)));
const auto * index_node =
&actions_dag->addColumn(std::move(index_col));
auto result_name = "sparkTupleElement(" + tuple_node->result_name + ",
" + index_node->result_name + ")";
@@ -866,10 +856,7 @@ ActionsDAG::NodeRawConstPtrs
SerializedPlanParser::parseArrayJoinWithDAG(
}
const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
- const substrait::Expression & rel,
- std::string & result_name,
- ActionsDAGPtr actions_dag,
- bool keep_result)
+ const substrait::Expression & rel, std::string & result_name,
ActionsDAGPtr actions_dag, bool keep_result)
{
if (!rel.has_scalar_function())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression
should be a scalar function:\n {}", rel.DebugString());
@@ -884,10 +871,7 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name,
this))
{
LOG_DEBUG(
- &Poco::Logger::get("SerializedPlanParser"),
- "parse function {} by function parser: {}",
- func_name,
- func_parser->getName());
+ &Poco::Logger::get("SerializedPlanParser"), "parse function {} by
function parser: {}", func_name, func_parser->getName());
const auto * result_node = func_parser->parse(scalar_function,
actions_dag);
if (keep_result)
actions_dag->addOrReplaceInOutputs(*result_node);
@@ -956,12 +940,10 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
UInt32 precision =
rel.scalar_function().output_type().decimal().precision();
UInt32 scale = rel.scalar_function().output_type().decimal().scale();
auto uint32_type = std::make_shared<DataTypeUInt32>();
- new_args.emplace_back(
- &actions_dag->addColumn(
- ColumnWithTypeAndName(uint32_type->createColumnConst(1,
precision), uint32_type, getUniqueName(toString(precision)))));
- new_args.emplace_back(
- &actions_dag->addColumn(
- ColumnWithTypeAndName(uint32_type->createColumnConst(1,
scale), uint32_type, getUniqueName(toString(scale)))));
+ new_args.emplace_back(&actions_dag->addColumn(
+ ColumnWithTypeAndName(uint32_type->createColumnConst(1,
precision), uint32_type, getUniqueName(toString(precision)))));
+ new_args.emplace_back(&actions_dag->addColumn(
+ ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale),
uint32_type, getUniqueName(toString(scale)))));
args = std::move(new_args);
}
else if (startsWith(function_signature, "make_decimal:"))
@@ -976,12 +958,10 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
UInt32 precision =
rel.scalar_function().output_type().decimal().precision();
UInt32 scale = rel.scalar_function().output_type().decimal().scale();
auto uint32_type = std::make_shared<DataTypeUInt32>();
- new_args.emplace_back(
- &actions_dag->addColumn(
- ColumnWithTypeAndName(uint32_type->createColumnConst(1,
precision), uint32_type, getUniqueName(toString(precision)))));
- new_args.emplace_back(
- &actions_dag->addColumn(
- ColumnWithTypeAndName(uint32_type->createColumnConst(1,
scale), uint32_type, getUniqueName(toString(scale)))));
+ new_args.emplace_back(&actions_dag->addColumn(
+ ColumnWithTypeAndName(uint32_type->createColumnConst(1,
precision), uint32_type, getUniqueName(toString(precision)))));
+ new_args.emplace_back(&actions_dag->addColumn(
+ ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale),
uint32_type, getUniqueName(toString(scale)))));
args = std::move(new_args);
}
@@ -999,9 +979,8 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
actions_dag,
function_node,
// as stated in isTypeMatched, currently we don't change
nullability of the result type
- function_node->result_type->isNullable()
- ? local_engine::wrapNullableType(true, result_type)->getName()
- : local_engine::removeNullable(result_type)->getName(),
+ function_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, result_type)->getName()
+ :
local_engine::removeNullable(result_type)->getName(),
function_node->result_name,
CastType::accurateOrNull);
}
@@ -1011,9 +990,8 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
actions_dag,
function_node,
// as stated in isTypeMatched, currently we don't change
nullability of the result type
- function_node->result_type->isNullable()
- ? local_engine::wrapNullableType(true, result_type)->getName()
- : local_engine::removeNullable(result_type)->getName(),
+ function_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, result_type)->getName()
+ :
local_engine::removeNullable(result_type)->getName(),
function_node->result_name);
}
}
@@ -1159,9 +1137,7 @@ void SerializedPlanParser::parseFunctionArgument(
}
const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
- ActionsDAGPtr & actions_dag,
- const std::string & function_name,
- const substrait::FunctionArgument & arg)
+ ActionsDAGPtr & actions_dag, const std::string & function_name, const
substrait::FunctionArgument & arg)
{
const ActionsDAG::Node * res;
if (arg.value().has_scalar_function())
@@ -1189,11 +1165,8 @@ std::pair<DataTypePtr, Field>
SerializedPlanParser::convertStructFieldType(const
}
auto type_id = type->getTypeId();
- if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id
== TypeIndex::UInt32
- || type_id == TypeIndex::UInt64)
- {
+ if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id
== TypeIndex::UInt32 || type_id == TypeIndex::UInt64)
return {type, field};
- }
UINT_CONVERT(type, field, Int8)
UINT_CONVERT(type, field, Int16)
UINT_CONVERT(type, field, Int32)
@@ -1203,11 +1176,7 @@ std::pair<DataTypePtr, Field>
SerializedPlanParser::convertStructFieldType(const
}
ActionsDAGPtr SerializedPlanParser::parseFunction(
- const Block & header,
- const substrait::Expression & rel,
- std::string & result_name,
- ActionsDAGPtr actions_dag,
- bool keep_result)
+ const Block & header, const substrait::Expression & rel, std::string &
result_name, ActionsDAGPtr actions_dag, bool keep_result)
{
if (!actions_dag)
actions_dag =
std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
@@ -1217,11 +1186,7 @@ ActionsDAGPtr SerializedPlanParser::parseFunction(
}
ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression(
- const Block & header,
- const substrait::Expression & rel,
- std::string & result_name,
- ActionsDAGPtr actions_dag,
- bool keep_result)
+ const Block & header, const substrait::Expression & rel, std::string &
result_name, ActionsDAGPtr actions_dag, bool keep_result)
{
if (!actions_dag)
actions_dag =
std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
@@ -1303,7 +1268,8 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
= &actions_dag->addFunction(json_extract_builder, {json_expr_node,
extract_expr_node}, json_extract_result_name);
auto tuple_element_builder =
FunctionFactory::instance().get("sparkTupleElement", context);
auto tuple_index_type = std::make_shared<DataTypeUInt32>();
- auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t
i) -> const ActionsDAG::Node * {
+ auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t
i) -> const ActionsDAG::Node *
+ {
ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1,
i), tuple_index_type, getUniqueName(std::to_string(i)));
const auto * index_node =
&actions_dag->addColumn(std::move(index_col));
auto result_name = "sparkTupleElement(" + tuple_node->result_name + ",
" + index_node->result_name + ")";
@@ -1528,9 +1494,7 @@ std::pair<DataTypePtr, Field>
SerializedPlanParser::parseLiteral(const substrait
}
default: {
throw Exception(
- ErrorCodes::UNKNOWN_TYPE,
- "Unsupported spark literal type {}",
- magic_enum::enum_name(literal.literal_type_case()));
+ ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}",
magic_enum::enum_name(literal.literal_type_case()));
}
}
return std::make_pair(std::move(type), std::move(field));
@@ -1732,8 +1696,7 @@ substrait::ReadRel::ExtensionTable
SerializedPlanParser::parseExtensionTable(con
{
substrait::ReadRel::ExtensionTable extension_table;
google::protobuf::io::CodedInputStream coded_in(
- reinterpret_cast<const uint8_t *>(split_info.data()),
- static_cast<int>(split_info.size()));
+ reinterpret_cast<const uint8_t *>(split_info.data()),
static_cast<int>(split_info.size()));
coded_in.SetRecursionLimit(100000);
auto ok = extension_table.ParseFromCodedStream(&coded_in);
@@ -1747,8 +1710,7 @@ substrait::ReadRel::LocalFiles
SerializedPlanParser::parseLocalFiles(const std::
{
substrait::ReadRel::LocalFiles local_files;
google::protobuf::io::CodedInputStream coded_in(
- reinterpret_cast<const uint8_t *>(split_info.data()),
- static_cast<int>(split_info.size()));
+ reinterpret_cast<const uint8_t *>(split_info.data()),
static_cast<int>(split_info.size()));
coded_in.SetRecursionLimit(100000);
auto ok = local_files.ParseFromCodedStream(&coded_in);
@@ -1758,10 +1720,51 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std::
     return local_files;
 }
+/// Build a query pipeline from `query_plan` and wrap it in a LocalExecutor together with the plan
+/// itself and an empty clone of the plan's output header.
+/// A QueryStatus is attached to the pipeline through `process_list_element`; the build time, the
+/// clickhouse plan and the resulting pipeline are logged.
+std::unique_ptr<LocalExecutor> SerializedPlanParser::createExecutor(DB::QueryPlanPtr query_plan)
+{
+    Stopwatch stopwatch;
+    auto * logger = &Poco::Logger::get("SerializedPlanParser");
+    const Settings & settings = context->getSettingsRef();
+
+    QueryPriorities priorities;
+    auto query_status = std::make_shared<QueryStatus>(
+        context,
+        "",
+        context->getClientInfo(),
+        priorities.insert(static_cast<int>(settings.priority)),
+        CurrentThread::getGroup(),
+        IAST::QueryKind::Select,
+        settings,
+        0);
+
+    QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations};
+    auto pipeline_builder = query_plan->buildQueryPipeline(
+        optimization_settings,
+        BuildQueryPipelineSettings{
+            .actions_settings
+            = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes},
+            .process_list_element = query_status});
+    QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder));
+    LOG_INFO(logger, "build pipeline {} ms", stopwatch.elapsedMicroseconds() / 1000.0);
+
+    LOG_DEBUG(
+        logger, "clickhouse plan [optimization={}]:\n{}", settings.query_plan_enable_optimizations, PlanUtil::explainPlan(*query_plan));
+    LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(pipeline));
+
+    /// Take the output header before `query_plan` is handed to the executor, so reading it does not
+    /// depend on the move only happening inside make_unique's forwarded constructor call.
+    auto header = query_plan->getCurrentDataStream().header.cloneEmpty();
+    return std::make_unique<LocalExecutor>(
+        context, std::move(query_plan), std::move(pipeline), std::move(header));
+}
-QueryPlanPtr SerializedPlanParser::parse(const std::string & plan)
+QueryPlanPtr SerializedPlanParser::parse(const std::string_view & plan)
{
- auto plan_ptr = std::make_unique<substrait::Plan>();
+ substrait::Plan s_plan;
///
https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data
/// Parsing may fail when the number of recursive layers is large.
/// Here, set a limit large enough to avoid this problem.
@@ -1769,11 +1765,10 @@ QueryPlanPtr SerializedPlanParser::parse(const
std::string & plan)
google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const
uint8_t *>(plan.data()), static_cast<int>(plan.size()));
coded_in.SetRecursionLimit(100000);
- auto ok = plan_ptr->ParseFromCodedStream(&coded_in);
- if (!ok)
+ if (!s_plan.ParseFromCodedStream(&coded_in))
throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse
substrait::Plan from string failed");
- auto res = parse(std::move(plan_ptr));
+ auto res = parse(s_plan);
#ifndef NDEBUG
PlanUtil::checkOuputType(*res);
@@ -1788,17 +1783,16 @@ QueryPlanPtr SerializedPlanParser::parse(const
std::string & plan)
return res;
}
-QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan)
+QueryPlanPtr SerializedPlanParser::parseJson(const std::string_view &
json_plan)
{
- auto plan_ptr = std::make_unique<substrait::Plan>();
- auto s =
google::protobuf::util::JsonStringToMessage(absl::string_view(json_plan),
plan_ptr.get());
+ substrait::Plan plan;
+ auto s = google::protobuf::util::JsonStringToMessage(json_plan, &plan);
if (!s.ok())
throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse
substrait::Plan from json string failed: {}", s.ToString());
- return parse(std::move(plan_ptr));
+ return parse(plan);
}
-SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_)
- : context(context_)
+SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) :
context(context_)
{
}
@@ -1807,13 +1801,10 @@ ContextMutablePtr SerializedPlanParser::global_context
= nullptr;
Context::ConfigurationPtr SerializedPlanParser::config = nullptr;
void SerializedPlanParser::collectJoinKeys(
- const substrait::Expression & condition,
- std::vector<std::pair<int32_t, int32_t>> & join_keys,
- int32_t right_key_start)
+ const substrait::Expression & condition, std::vector<std::pair<int32_t,
int32_t>> & join_keys, int32_t right_key_start)
{
auto condition_name = getFunctionName(
-
function_mapping.at(std::to_string(condition.scalar_function().function_reference())),
- condition.scalar_function());
+
function_mapping.at(std::to_string(condition.scalar_function().function_reference())),
condition.scalar_function());
if (condition_name == "and")
{
collectJoinKeys(condition.scalar_function().arguments(0).value(),
join_keys, right_key_start);
@@ -1863,8 +1854,8 @@ ASTPtr ASTParser::parseToAST(const Names & names, const
substrait::Expression &
auto substrait_name = function_signature.substr(0,
function_signature.find(':'));
auto func_parser =
FunctionParserFactory::instance().tryGet(substrait_name, plan_parser);
- String function_name = func_parser ? func_parser->getName()
- :
SerializedPlanParser::getFunctionName(function_signature, scalar_function);
+ String function_name
+ = func_parser ? func_parser->getName() :
SerializedPlanParser::getFunctionName(function_signature, scalar_function);
ASTs ast_args;
parseFunctionArgumentsToAST(names, scalar_function, ast_args);
@@ -1876,9 +1867,7 @@ ASTPtr ASTParser::parseToAST(const Names & names, const
substrait::Expression &
}
void ASTParser::parseFunctionArgumentsToAST(
- const Names & names,
- const substrait::Expression_ScalarFunction & scalar_function,
- ASTs & ast_args)
+ const Names & names, const substrait::Expression_ScalarFunction &
scalar_function, ASTs & ast_args)
{
const auto & args = scalar_function.arguments();
@@ -2021,12 +2010,12 @@ ASTPtr ASTParser::parseArgumentToAST(const Names &
names, const substrait::Expre
}
}
-void SerializedPlanParser::removeNullableForRequiredColumns(const
std::set<String> & require_columns, ActionsDAGPtr actions_dag)
+void SerializedPlanParser::removeNullableForRequiredColumns(
+ const std::set<String> & require_columns, const ActionsDAGPtr &
actions_dag) const
{
for (const auto & item : require_columns)
{
- const auto * require_node = actions_dag->tryFindInOutputs(item);
- if (require_node)
+ if (const auto * require_node = actions_dag->tryFindInOutputs(item))
{
auto function_builder =
FunctionFactory::instance().get("assumeNotNull", context);
ActionsDAG::NodeRawConstPtrs args = {require_node};
@@ -2037,9 +2026,7 @@ void
SerializedPlanParser::removeNullableForRequiredColumns(const std::set<Strin
}
void SerializedPlanParser::wrapNullable(
- const std::vector<String> & columns,
- ActionsDAGPtr actions_dag,
- std::map<std::string, std::string> & nullable_measure_names)
+ const std::vector<String> & columns, ActionsDAGPtr actions_dag,
std::map<std::string, std::string> & nullable_measure_names)
{
for (const auto & item : columns)
{
@@ -2092,86 +2079,23 @@ LocalExecutor::~LocalExecutor()
}
}
-
-void LocalExecutor::execute(QueryPlanPtr query_plan)
-{
- Stopwatch stopwatch;
-
- const Settings & settings = context->getSettingsRef();
- current_query_plan = std::move(query_plan);
- auto * logger = &Poco::Logger::get("LocalExecutor");
-
- QueryPriorities priorities;
- auto query_status = std::make_shared<QueryStatus>(
- context,
- "",
- context->getClientInfo(),
- priorities.insert(static_cast<int>(settings.priority)),
- CurrentThread::getGroup(),
- IAST::QueryKind::Select,
- settings,
- 0);
-
- QueryPlanOptimizationSettings optimization_settings{.optimize_plan =
settings.query_plan_enable_optimizations};
- auto pipeline_builder = current_query_plan->buildQueryPipeline(
- optimization_settings,
- BuildQueryPipelineSettings{
- .actions_settings
- = ExpressionActionsSettings{.can_compile_expressions = true,
.min_count_to_compile_expression = 3,
- .compile_expressions =
CompileExpressions::yes},
- .process_list_element = query_status});
-
- LOG_DEBUG(logger, "clickhouse plan after optimization:\n{}",
PlanUtil::explainPlan(*current_query_plan));
- query_pipeline =
QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder));
- LOG_DEBUG(logger, "clickhouse pipeline:\n{}",
QueryPipelineUtil::explainPipeline(query_pipeline));
- auto t_pipeline = stopwatch.elapsedMicroseconds();
-
- executor = std::make_unique<PullingPipelineExecutor>(query_pipeline);
- auto t_executor = stopwatch.elapsedMicroseconds() - t_pipeline;
- stopwatch.stop();
- LOG_INFO(
- logger,
- "build pipeline {} ms; create executor {} ms;",
- t_pipeline / 1000.0,
- t_executor / 1000.0);
-
- header = current_query_plan->getCurrentDataStream().header.cloneEmpty();
- ch_column_to_spark_row = std::make_unique<CHColumnToSparkRow>();
-}
-
-std::unique_ptr<SparkRowInfo> LocalExecutor::writeBlockToSparkRow(Block &
block)
+std::unique_ptr<SparkRowInfo> LocalExecutor::writeBlockToSparkRow(const Block
& block) const
{
return ch_column_to_spark_row->convertCHColumnToSparkRow(block);
}
bool LocalExecutor::hasNext()
{
- bool has_next;
- try
+ size_t columns = currentBlock().columns();
+ if (columns == 0 || isConsumed())
{
- size_t columns = currentBlock().columns();
- if (columns == 0 || isConsumed())
- {
- auto empty_block = header.cloneEmpty();
- setCurrentBlock(empty_block);
- has_next = executor->pull(currentBlock());
- produce();
- }
- else
- {
- has_next = true;
- }
- }
- catch (Exception & e)
- {
- LOG_ERROR(
- &Poco::Logger::get("LocalExecutor"),
- "LocalExecutor run query plan failed with message: {}. Plan
Explained: \n{}",
- e.message(),
- PlanUtil::explainPlan(*current_query_plan));
- throw;
+ auto empty_block = header.cloneEmpty();
+ setCurrentBlock(empty_block);
+ bool has_next = executor->pull(currentBlock());
+ produce();
+ return has_next;
}
- return has_next;
+ return true;
}
SparkRowInfoPtr LocalExecutor::next()
@@ -2246,12 +2170,17 @@ Block & LocalExecutor::getHeader()
return header;
}
-LocalExecutor::LocalExecutor(ContextPtr context_)
- : context(context_)
+LocalExecutor::LocalExecutor(const ContextPtr & context_, QueryPlanPtr
query_plan, QueryPipeline && pipeline, const Block & header_)
+ : query_pipeline(std::move(pipeline))
+ , executor(std::make_unique<PullingPipelineExecutor>(query_pipeline))
+ , header(header_)
+ , context(context_)
+ , ch_column_to_spark_row(std::make_unique<CHColumnToSparkRow>())
+ , current_query_plan(std::move(query_plan))
{
}
-std::string LocalExecutor::dumpPipeline()
+std::string LocalExecutor::dumpPipeline() const
{
const auto & processors = query_pipeline.getProcessors();
for (auto & processor : processors)
@@ -2275,12 +2204,8 @@ std::string LocalExecutor::dumpPipeline()
}
NonNullableColumnsResolver::NonNullableColumnsResolver(
- const Block & header_,
- SerializedPlanParser & parser_,
- const substrait::Expression & cond_rel_)
- : header(header_)
- , parser(parser_)
- , cond_rel(cond_rel_)
+ const Block & header_, SerializedPlanParser & parser_, const
substrait::Expression & cond_rel_)
+ : header(header_), parser(parser_), cond_rel(cond_rel_)
{
}
@@ -2352,8 +2277,7 @@ void NonNullableColumnsResolver::visitNonNullable(const
substrait::Expression &
}
std::string NonNullableColumnsResolver::safeGetFunctionName(
- const std::string & function_signature,
- const substrait::Expression_ScalarFunction & function)
+ const std::string & function_signature, const
substrait::Expression_ScalarFunction & function) const
{
try
{
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 71cdca58a..82e8c4077 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -218,6 +218,7 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr
nested_type);
std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c);
class SerializedPlanParser;
+class LocalExecutor;
// Give a condition expression `cond_rel_`, found all columns with nullability
that must not containt
// null after this filter.
@@ -241,7 +242,7 @@ private:
void visit(const substrait::Expression & expr);
void visitNonNullable(const substrait::Expression & expr);
- String safeGetFunctionName(const String & function_signature, const
substrait::Expression_ScalarFunction & function);
+ String safeGetFunctionName(const String & function_signature, const
substrait::Expression_ScalarFunction & function) const;
};
class SerializedPlanParser
@@ -257,11 +258,21 @@ private:
friend class JoinRelParser;
friend class MergeTreeRelParser;
+ std::unique_ptr<LocalExecutor> createExecutor(DB::QueryPlanPtr query_plan);
+
+ DB::QueryPlanPtr parse(const std::string_view & plan);
+ DB::QueryPlanPtr parse(const substrait::Plan & plan);
+
public:
explicit SerializedPlanParser(const ContextPtr & context);
- DB::QueryPlanPtr parse(const std::string & plan);
- DB::QueryPlanPtr parseJson(const std::string & json_plan);
- DB::QueryPlanPtr parse(std::unique_ptr<substrait::Plan> plan);
+
+ /// UT only
+ DB::QueryPlanPtr parseJson(const std::string_view & json_plan);
+ std::unique_ptr<LocalExecutor> createExecutor(const substrait::Plan &
plan) { return createExecutor(parse((plan))); }
+ ///
+
+ template <bool JsonPlan>
+ std::unique_ptr<LocalExecutor> createExecutor(const std::string_view &
plan);
DB::QueryPlanStepPtr parseReadRealWithLocalFile(const substrait::ReadRel &
rel);
DB::QueryPlanStepPtr parseReadRealWithJavaIter(const substrait::ReadRel &
rel);
@@ -372,7 +383,7 @@ private:
const ActionsDAG::Node *
toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const
DB::ActionsDAG::NodeRawConstPtrs & args);
// remove nullable after isNotNull
- void removeNullableForRequiredColumns(const std::set<String> &
require_columns, ActionsDAGPtr actions_dag);
+ void removeNullableForRequiredColumns(const std::set<String> &
require_columns, const ActionsDAGPtr & actions_dag) const;
std::string getUniqueName(const std::string & name) { return name + "_" +
std::to_string(name_no++); }
static std::pair<DataTypePtr, Field> parseLiteral(const
substrait::Expression_Literal & literal);
void wrapNullable(
@@ -394,6 +405,12 @@ public:
const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const
DataTypePtr & type, const Field & field);
};
+template <bool JsonPlan>
+std::unique_ptr<LocalExecutor> SerializedPlanParser::createExecutor(const
std::string_view & plan)
+{
+ return createExecutor(JsonPlan ? parseJson(plan) : parse(plan));
+}
+
struct SparkBuffer
{
char * address;
@@ -403,16 +420,14 @@ struct SparkBuffer
class LocalExecutor : public BlockIterator
{
public:
- LocalExecutor() = default;
- explicit LocalExecutor(ContextPtr context);
+ LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan,
QueryPipeline && pipeline, const Block & header_);
~LocalExecutor();
- void execute(QueryPlanPtr query_plan);
SparkRowInfoPtr next();
Block * nextColumnar();
bool hasNext();
- /// Stop execution and wait for pipeline exit, used when task receives
shutdown command or executor receives SIGTERM signal
+ /// Stop execution, used when task receives shutdown command or executor
receives SIGTERM signal
void cancel();
Block & getHeader();
@@ -425,13 +440,13 @@ public:
static void removeExecutor(Int64 handle);
private:
- std::unique_ptr<SparkRowInfo> writeBlockToSparkRow(DB::Block & block);
+ std::unique_ptr<SparkRowInfo> writeBlockToSparkRow(const DB::Block &
block) const;
void asyncCancel();
void waitCancelFinished();
/// Dump processor runtime information to log
- std::string dumpPipeline();
+ std::string dumpPipeline() const;
QueryPipeline query_pipeline;
std::unique_ptr<PullingPipelineExecutor> executor;
@@ -439,7 +454,7 @@ private:
ContextPtr context;
std::unique_ptr<CHColumnToSparkRow> ch_column_to_spark_row;
std::unique_ptr<SparkBuffer> spark_buffer;
- DB::QueryPlanPtr current_query_plan;
+ QueryPlanPtr current_query_plan;
RelMetricPtr metric;
std::vector<QueryPlanPtr> extra_plan_holder;
std::atomic<bool> is_cancelled{false};
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp
b/cpp-ch/local-engine/local_engine_jni.cpp
index bbc467879..9c642d70e 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -224,11 +224,9 @@ JNIEXPORT void JNI_OnUnload(JavaVM * vm, void *
/*reserved*/)
JNIEXPORT void
Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative(JNIEnv
* env, jobject, jbyteArray conf_plan)
{
LOCAL_ENGINE_JNI_METHOD_START
- jsize plan_buf_size = env->GetArrayLength(conf_plan);
+ std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan);
jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr);
- std::string plan_str;
- plan_str.assign(reinterpret_cast<const char *>(plan_buf_addr),
plan_buf_size);
- local_engine::BackendInitializerUtil::init(&plan_str);
+ local_engine::BackendInitializerUtil::init({reinterpret_cast<const char
*>(plan_buf_addr), plan_buf_size});
env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT);
LOCAL_ENGINE_JNI_METHOD_END(env, )
}
@@ -254,11 +252,9 @@ JNIEXPORT jlong
Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_
auto query_context =
local_engine::getAllocator(allocator_id)->query_context;
// by task update new configs ( in case of dynamic config update )
- jsize plan_buf_size = env->GetArrayLength(conf_plan);
+ std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan);
jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr);
- std::string plan_str;
- plan_str.assign(reinterpret_cast<const char *>(plan_buf_addr),
plan_buf_size);
- local_engine::BackendInitializerUtil::updateConfig(query_context,
&plan_str);
+ local_engine::BackendInitializerUtil::updateConfig(query_context,
{reinterpret_cast<const char *>(plan_buf_addr), plan_buf_size});
local_engine::SerializedPlanParser parser(query_context);
jsize iter_num = env->GetArrayLength(iter_arr);
@@ -277,17 +273,14 @@ JNIEXPORT jlong
Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_
parser.addSplitInfo(std::string{reinterpret_cast<const char
*>(split_info_addr), split_info_size});
}
- jsize plan_size = env->GetArrayLength(plan);
+ std::string::size_type plan_size = env->GetArrayLength(plan);
jbyte * plan_address = env->GetByteArrayElements(plan, nullptr);
- std::string plan_string;
- plan_string.assign(reinterpret_cast<const char *>(plan_address),
plan_size);
- auto query_plan = parser.parse(plan_string);
- local_engine::LocalExecutor * executor = new
local_engine::LocalExecutor(query_context);
+ local_engine::LocalExecutor * executor
+ = parser.createExecutor<false>({reinterpret_cast<const char
*>(plan_address), plan_size}).release();
local_engine::LocalExecutor::addExecutor(executor);
- LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}",
reinterpret_cast<intptr_t>(executor));
+ LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}",
reinterpret_cast<uintptr_t>(executor));
executor->setMetric(parser.getMetric());
executor->setExtraPlanHolder(parser.extra_plan_holder);
- executor->execute(std::move(query_plan));
env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT);
env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT);
return reinterpret_cast<jlong>(executor);
@@ -932,11 +925,10 @@ JNIEXPORT jlong
Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniW
LOCAL_ENGINE_JNI_METHOD_START
auto query_context =
local_engine::getAllocator(allocator_id)->query_context;
// by task update new configs ( in case of dynamic config update )
- jsize conf_plan_buf_size = env->GetArrayLength(conf_plan);
+ std::string::size_type conf_plan_buf_size = env->GetArrayLength(conf_plan);
jbyte * conf_plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr);
- std::string conf_plan_str;
- conf_plan_str.assign(reinterpret_cast<const char *>(conf_plan_buf_addr),
conf_plan_buf_size);
- local_engine::BackendInitializerUtil::updateConfig(query_context,
&conf_plan_str);
+ local_engine::BackendInitializerUtil::updateConfig(
+ query_context, {reinterpret_cast<const char *>(conf_plan_buf_addr),
conf_plan_buf_size});
const auto uuid_str = jstring2string(env, uuid_);
const auto task_id = jstring2string(env, task_id_);
@@ -1329,14 +1321,11 @@
Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE
local_engine::SerializedPlanParser parser(context);
jobject iter = env->NewGlobalRef(input);
parser.addInputIter(iter, false);
- jsize plan_size = env->GetArrayLength(plan);
+ std::string::size_type plan_size = env->GetArrayLength(plan);
jbyte * plan_address = env->GetByteArrayElements(plan, nullptr);
- std::string plan_string;
- plan_string.assign(reinterpret_cast<const char *>(plan_address),
plan_size);
- auto query_plan = parser.parse(plan_string);
- local_engine::LocalExecutor * executor = new
local_engine::LocalExecutor(context);
+ local_engine::LocalExecutor * executor
+ = parser.createExecutor<false>({reinterpret_cast<const char
*>(plan_address), plan_size}).release();
local_engine::LocalExecutor::addExecutor(executor);
- executor->execute(std::move(query_plan));
env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT);
return reinterpret_cast<jlong>(executor);
LOCAL_ENGINE_JNI_METHOD_END(env, -1)
diff --git a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
index 89fa4fa96..208a3b518 100644
--- a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
@@ -154,14 +154,11 @@ DB::ContextMutablePtr global_context;
std::move(schema))
.build();
local_engine::SerializedPlanParser parser(global_context);
- auto query_plan = parser.parse(std::move(plan));
- local_engine::LocalExecutor local_executor;
+ auto local_executor = parser.createExecutor(*plan);
state.ResumeTiming();
- local_executor.execute(std::move(query_plan));
- while (local_executor.hasNext())
- {
- local_engine::SparkRowInfoPtr spark_row_info =
local_executor.next();
- }
+
+ while (local_executor->hasNext())
+ local_engine::SparkRowInfoPtr spark_row_info =
local_executor->next();
}
}
@@ -212,13 +209,12 @@ DB::ContextMutablePtr global_context;
std::move(schema))
.build();
local_engine::SerializedPlanParser
parser(SerializedPlanParser::global_context);
- auto query_plan = parser.parse(std::move(plan));
- local_engine::LocalExecutor local_executor;
+ auto local_executor = parser.createExecutor(*plan);
state.ResumeTiming();
- local_executor.execute(std::move(query_plan));
- while (local_executor.hasNext())
+
+ while (local_executor->hasNext())
{
- Block * block = local_executor.nextColumnar();
+ Block * block = local_executor->nextColumnar();
delete block;
}
}
@@ -238,15 +234,10 @@ DB::ContextMutablePtr global_context;
std::ifstream t(path);
std::string str((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>());
std::cout << "the plan from: " << path << std::endl;
-
- auto query_plan = parser.parse(str);
- local_engine::LocalExecutor local_executor;
+ auto local_executor = parser.createExecutor<false>(str);
state.ResumeTiming();
- local_executor.execute(std::move(query_plan));
- while (local_executor.hasNext())
- {
- [[maybe_unused]] auto * x = local_executor.nextColumnar();
- }
+ while (local_executor->hasNext()) [[maybe_unused]]
+ auto * x = local_executor->nextColumnar();
}
}
@@ -282,14 +273,12 @@ DB::ContextMutablePtr global_context;
std::move(schema))
.build();
local_engine::SerializedPlanParser
parser(SerializedPlanParser::global_context);
- auto query_plan = parser.parse(std::move(plan));
- local_engine::LocalExecutor local_executor;
+
+ auto local_executor = parser.createExecutor(*plan);
state.ResumeTiming();
- local_executor.execute(std::move(query_plan));
- while (local_executor.hasNext())
- {
- local_engine::SparkRowInfoPtr spark_row_info =
local_executor.next();
- }
+
+ while (local_executor->hasNext())
+ local_engine::SparkRowInfoPtr spark_row_info =
local_executor->next();
}
}
@@ -320,16 +309,13 @@ DB::ContextMutablePtr global_context;
.build();
local_engine::SerializedPlanParser
parser(SerializedPlanParser::global_context);
- auto query_plan = parser.parse(std::move(plan));
- local_engine::LocalExecutor local_executor;
-
- local_executor.execute(std::move(query_plan));
+ auto local_executor = parser.createExecutor(*plan);
local_engine::SparkRowToCHColumn converter;
- while (local_executor.hasNext())
+ while (local_executor->hasNext())
{
- local_engine::SparkRowInfoPtr spark_row_info =
local_executor.next();
+ local_engine::SparkRowInfoPtr spark_row_info =
local_executor->next();
state.ResumeTiming();
- auto block =
converter.convertSparkRowInfoToCHColumn(*spark_row_info,
local_executor.getHeader());
+ auto block =
converter.convertSparkRowInfoToCHColumn(*spark_row_info,
local_executor->getHeader());
state.PauseTiming();
}
state.ResumeTiming();
@@ -368,16 +354,13 @@ DB::ContextMutablePtr global_context;
std::move(schema))
.build();
local_engine::SerializedPlanParser
parser(SerializedPlanParser::global_context);
- auto query_plan = parser.parse(std::move(plan));
- local_engine::LocalExecutor local_executor;
-
- local_executor.execute(std::move(query_plan));
+ auto local_executor = parser.createExecutor(*plan);
local_engine::SparkRowToCHColumn converter;
- while (local_executor.hasNext())
+ while (local_executor->hasNext())
{
- local_engine::SparkRowInfoPtr spark_row_info =
local_executor.next();
+ local_engine::SparkRowInfoPtr spark_row_info =
local_executor->next();
state.ResumeTiming();
- auto block =
converter.convertSparkRowInfoToCHColumn(*spark_row_info,
local_executor.getHeader());
+ auto block =
converter.convertSparkRowInfoToCHColumn(*spark_row_info,
local_executor->getHeader());
state.PauseTiming();
}
state.ResumeTiming();
@@ -485,12 +468,8 @@ DB::ContextMutablePtr global_context;
y.reserve(cnt);
for (auto _ : state)
- {
for (i = 0; i < cnt; i++)
- {
y[i] = add(x[i], i);
- }
- }
}
[[maybe_unused]] static void BM_TestSumInline(benchmark::State & state)
@@ -504,12 +483,8 @@ DB::ContextMutablePtr global_context;
y.reserve(cnt);
for (auto _ : state)
- {
for (i = 0; i < cnt; i++)
- {
y[i] = x[i] + i;
- }
- }
}
[[maybe_unused]] static void BM_TestPlus(benchmark::State & state)
@@ -545,9 +520,7 @@ DB::ContextMutablePtr global_context;
block.insert(y);
auto executable_function = function->prepare(arguments);
for (auto _ : state)
- {
auto result =
executable_function->execute(block.getColumnsWithTypeAndName(), type, rows,
false);
- }
}
[[maybe_unused]] static void BM_TestPlusEmbedded(benchmark::State & state)
@@ -847,9 +820,7 @@ QueryPlanPtr joinPlan(QueryPlanPtr left, QueryPlanPtr
right, String left_key, St
ASTPtr rkey = std::make_shared<ASTIdentifier>(right_key);
join->addOnKeys(lkey, rkey, true);
for (const auto & column : join->columnsFromJoinedTable())
- {
join->addJoinedColumn(column);
- }
auto left_keys =
left->getCurrentDataStream().header.getNamesAndTypesList();
join->addJoinedColumnsAndCorrectTypes(left_keys, true);
@@ -920,7 +891,8 @@
BENCHMARK(BM_ParquetRead)->Unit(benchmark::kMillisecond)->Iterations(10);
int main(int argc, char ** argv)
{
- BackendInitializerUtil::init(nullptr);
+ std::string empty;
+ BackendInitializerUtil::init(empty);
SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); });
::benchmark::Initialize(&argc, argv);
diff --git a/cpp-ch/local-engine/tests/gluten_test_util.h
b/cpp-ch/local-engine/tests/gluten_test_util.h
index d4c16e9fb..dba4496d6 100644
--- a/cpp-ch/local-engine/tests/gluten_test_util.h
+++ b/cpp-ch/local-engine/tests/gluten_test_util.h
@@ -24,6 +24,7 @@
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Interpreters/ActionsDAG.h>
+#include <google/protobuf/util/json_util.h>
#include <parquet/schema.h>
using BlockRowType = DB::ColumnsWithTypeAndName;
@@ -60,6 +61,23 @@ AnotherRowType readParquetSchema(const std::string & file);
DB::ActionsDAGPtr parseFilter(const std::string & filter, const AnotherRowType
& name_and_types);
+namespace pb_util
+{
+/// Parse a JSON-encoded protobuf message of type `Message` and return its binary
+/// wire encoding. Throws std::runtime_error if JSON parsing or serialization fails.
+template <typename Message>
+std::string JsonStringToBinary(const std::string_view & json)
+{
+    Message message;
+    auto s = google::protobuf::util::JsonStringToMessage(json, &message);
+    if (!s.ok())
+        throw std::runtime_error(std::string{s.message()});
+    std::string binary;
+    if (!message.SerializeToString(&binary))
+        throw std::runtime_error("Failed to serialize protobuf message to binary");
+    return binary;
+}
+}
}
inline DB::DataTypePtr BIGINT()
diff --git a/cpp-ch/local-engine/tests/gtest_local_engine.cpp
b/cpp-ch/local-engine/tests/gtest_local_engine.cpp
index 2d1807841..962bf9def 100644
--- a/cpp-ch/local-engine/tests/gtest_local_engine.cpp
+++ b/cpp-ch/local-engine/tests/gtest_local_engine.cpp
@@ -16,9 +16,12 @@
*/
#include <fstream>
#include <iostream>
+#include <gluten_test_util.h>
+#include <incbin.h>
+
#include <Builder/SerializedPlanBuilder.h>
-#include <DataTypes/DataTypesNumber.h>
#include <Disks/DiskLocal.h>
+#include <Formats/FormatFactory.h>
#include <Interpreters/Context.h>
#include <Parser/CHColumnToSparkRow.h>
#include <Parser/SerializedPlanParser.h>
@@ -28,7 +31,6 @@
#include <Storages/CustomMergeTreeSink.h>
#include <Storages/CustomStorageMergeTree.h>
#include <Storages/MergeTree/MergeTreeData.h>
-#include <Storages/SubstraitSource/SubstraitFileSource.h>
#include <gtest/gtest.h>
#include <substrait/plan.pb.h>
#include <Common/CHUtil.h>
@@ -84,13 +86,23 @@ TEST(ReadBufferFromFile, seekBackwards)
ASSERT_EQ(x, 8);
}
+INCBIN(resource_embedded_config_json, SOURCE_DIR
"/utils/extern-local-engine/tests/json/gtest_local_engine_config.json");
+
+namespace DB
+{
+void registerOutputFormatParquet(DB::FormatFactory & factory);
+}
+
int main(int argc, char ** argv)
{
- auto * init = new
String("{\"advancedExtensions\":{\"enhancement\":{\"@type\":\"type.googleapis.com/substrait.Expression\",\"literal\":{\"map\":{\"keyValues\":[{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level\"},\"value\":{\"string\":\"trace\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.endpoint\"
[...]
+
BackendInitializerUtil::init(test::pb_util::JsonStringToBinary<substrait::Plan>(
+ {reinterpret_cast<const char *>(gresource_embedded_config_jsonData),
gresource_embedded_config_jsonSize}));
+
+ auto & factory = FormatFactory::instance();
+ DB::registerOutputFormatParquet(factory);
- BackendInitializerUtil::init_json(std::move(init));
SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); });
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
-}
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/tests/gtest_parser.cpp
b/cpp-ch/local-engine/tests/gtest_parser.cpp
index cbe41c90c..485740191 100644
--- a/cpp-ch/local-engine/tests/gtest_parser.cpp
+++ b/cpp-ch/local-engine/tests/gtest_parser.cpp
@@ -14,307 +14,140 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#include <gluten_test_util.h>
+#include <incbin.h>
#include <Parser/SerializedPlanParser.h>
-#include <google/protobuf/util/json_util.h>
#include <gtest/gtest.h>
+
using namespace local_engine;
using namespace DB;
-std::string splitBinaryFromJson(const std::string & json)
+// Plan for https://github.com/ClickHouse/ClickHouse/pull/65234
+INCBIN(resource_embedded_pr_65234_json, SOURCE_DIR
"/utils/extern-local-engine/tests/json/clickhouse_pr_65234.json");
+
+TEST(SerializedPlanParser, PR65234)
{
- std::string binary;
- substrait::ReadRel::LocalFiles local_files;
- auto s =
google::protobuf::util::JsonStringToMessage(absl::string_view(json),
&local_files);
- local_files.SerializeToString(&binary);
- return binary;
+ const std::string split
+ =
R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-clickhouse/target/scala-2.12/test-classes/tests-working-home/tpch-data/supplier/part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]})";
+ SerializedPlanParser parser(SerializedPlanParser::global_context);
+
parser.addSplitInfo(test::pb_util::JsonStringToBinary<substrait::ReadRel::LocalFiles>(split));
+ auto query_plan
+ = parser.parseJson({reinterpret_cast<const char
*>(gresource_embedded_pr_65234_jsonData),
gresource_embedded_pr_65234_jsonSize});
}
-std::string JsonPlanFor65234()
+#include <Disks/ObjectStorages/HDFS/HDFSObjectStorage.h>
+#include <Parsers/ParserCreateQuery.h>
+#include <Parsers/parseQuery.h>
+#include <Storages/ObjectStorage/HDFS/Configuration.h>
+#include <Storages/ObjectStorage/StorageObjectStorageSink.h>
+
+Chunk testChunk()
{
- // Plan for https://github.com/ClickHouse/ClickHouse/pull/65234
- return R"(
+ auto nameCol = STRING()->createColumn();
+ nameCol->insert("one");
+ nameCol->insert("two");
+ nameCol->insert("three");
+
+ auto valueCol = UINT()->createColumn();
+ valueCol->insert(1);
+ valueCol->insert(2);
+ valueCol->insert(3);
+ MutableColumns x;
+ x.push_back(std::move(nameCol));
+ x.push_back(std::move(valueCol));
+ return {std::move(x), 3};
+}
+
+TEST(LocalExecutor, StorageObjectStorageSink)
{
- "extensions": [{
- "extensionFunction": {
- "functionAnchor": 1,
- "name": "is_not_null:str"
- }
- }, {
- "extensionFunction": {
- "functionAnchor": 2,
- "name": "equal:str_str"
- }
- }, {
- "extensionFunction": {
- "functionAnchor": 3,
- "name": "is_not_null:i64"
- }
- }, {
- "extensionFunction": {
- "name": "and:bool_bool"
- }
- }],
- "relations": [{
- "root": {
- "input": {
- "project": {
- "common": {
- "emit": {
- "outputMapping": [2]
- }
- },
- "input": {
- "filter": {
- "common": {
- "direct": {
- }
- },
- "input": {
- "read": {
- "common": {
- "direct": {
- }
- },
- "baseSchema": {
- "names": ["r_regionkey", "r_name"],
- "struct": {
- "types": [{
- "i64": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- }, {
- "string": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- }]
- },
- "columnTypes": ["NORMAL_COL", "NORMAL_COL"]
- },
- "filter": {
- "scalarFunction": {
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- },
- "arguments": [{
- "value": {
- "scalarFunction": {
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- },
- "arguments": [{
- "value": {
- "scalarFunction": {
- "functionReference": 1,
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_REQUIRED"
- }
- },
- "arguments": [{
- "value": {
- "selection": {
- "directReference": {
- "structField": {
- "field": 1
- }
- }
- }
- }
- }]
- }
- }
- }, {
- "value": {
- "scalarFunction": {
- "functionReference": 2,
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- },
- "arguments": [{
- "value": {
- "selection": {
- "directReference": {
- "structField": {
- "field": 1
- }
- }
- }
- }
- }, {
- "value": {
- "literal": {
- "string": "EUROPE"
- }
- }
- }]
- }
- }
- }]
- }
- }
- }, {
- "value": {
- "scalarFunction": {
- "functionReference": 3,
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_REQUIRED"
- }
- },
- "arguments": [{
- "value": {
- "selection": {
- "directReference": {
- "structField": {
- }
- }
- }
- }
- }]
- }
- }
- }]
- }
- },
- "advancedExtension": {
- "optimization": {
- "@type":
"type.googleapis.com/google.protobuf.StringValue",
- "value": "isMergeTree\u003d0\n"
- }
- }
- }
- },
- "condition": {
- "scalarFunction": {
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- },
- "arguments": [{
- "value": {
- "scalarFunction": {
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- },
- "arguments": [{
- "value": {
- "scalarFunction": {
- "functionReference": 1,
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_REQUIRED"
- }
- },
- "arguments": [{
- "value": {
- "selection": {
- "directReference": {
- "structField": {
- "field": 1
- }
- }
- }
- }
- }]
- }
- }
- }, {
- "value": {
- "scalarFunction": {
- "functionReference": 2,
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- },
- "arguments": [{
- "value": {
- "selection": {
- "directReference": {
- "structField": {
- "field": 1
- }
- }
- }
- }
- }, {
- "value": {
- "literal": {
- "string": "EUROPE"
- }
- }
- }]
- }
- }
- }]
- }
- }
- }, {
- "value": {
- "scalarFunction": {
- "functionReference": 3,
- "outputType": {
- "bool": {
- "nullability": "NULLABILITY_REQUIRED"
- }
- },
- "arguments": [{
- "value": {
- "selection": {
- "directReference": {
- "structField": {
- }
- }
- }
- }
- }]
- }
- }
- }]
- }
- }
- }
- },
- "expressions": [{
- "selection": {
- "directReference": {
- "structField": {
- }
- }
- }
- }]
- }
- },
- "names": ["r_regionkey#72"],
- "outputSchema": {
- "types": [{
- "i64": {
- "nullability": "NULLABILITY_NULLABLE"
- }
- }],
- "nullability": "NULLABILITY_REQUIRED"
- }
- }
- }]
+ /// 0. Create ObjectStorage for HDFS
+ auto settings = SerializedPlanParser::global_context->getSettingsRef();
+ const std::string query
+ = R"(CREATE TABLE hdfs_engine_xxxx (name String, value UInt32)
ENGINE=HDFS('hdfs://localhost:8020/clickhouse/test2', 'Parquet'))";
+ DB::ParserCreateQuery parser;
+ std::string error_message;
+ const char * pos = query.data();
+ auto ast = DB::tryParseQuery(
+ parser,
+ pos,
+ pos + query.size(),
+ error_message,
+ /* hilite = */ false,
+ "QUERY TEST",
+ /* allow_multi_statements = */ false,
+ 0,
+ settings.max_parser_depth,
+ settings.max_parser_backtracks,
+ true);
+ auto & create = ast->as<ASTCreateQuery &>();
+ auto arg = create.storage->children[0];
+ const auto * func = arg->as<const ASTFunction>();
+ EXPECT_TRUE(func && func->name == "HDFS");
+
+ DB::StorageHDFSConfiguration config;
+ StorageObjectStorage::Configuration::initialize(config,
arg->children[0]->children, SerializedPlanParser::global_context, false);
+
+ const std::shared_ptr<DB::HDFSObjectStorage> object_storage
+ =
std::dynamic_pointer_cast<DB::HDFSObjectStorage>(config.createObjectStorage(SerializedPlanParser::global_context,
false));
+ EXPECT_TRUE(object_storage != nullptr);
+
+ RelativePathsWithMetadata files_with_metadata;
+ object_storage->listObjects("/clickhouse", files_with_metadata, 0);
+
+ /// 1. Create ObjectStorageSink
+ DB::StorageObjectStorageSink sink{
+ object_storage, config.clone(), {}, {{STRING(), "name"}, {UINT(),
"value"}}, SerializedPlanParser::global_context, ""};
+
+ /// 2. Create Chunk
+ /// 3. comsume
+ sink.consume(testChunk());
+ sink.onFinish();
}
-)";
+
+namespace DB
+{
+SinkToStoragePtr createFilelinkSink(
+ const StorageMetadataPtr & metadata_snapshot,
+ const String & table_name_for_log,
+ const String & path,
+ CompressionMethod compression_method,
+ const std::optional<FormatSettings> & format_settings,
+ const String & format_name,
+ const ContextPtr & context,
+ int flags);
}
-TEST(SerializedPlanParser, PR65234)
+INCBIN(resource_embedded_readcsv_json, SOURCE_DIR
"/utils/extern-local-engine/tests/json/read_student_option_schema.csv.json");
+TEST(LocalExecutor, StorageFileSink)
{
const std::string split
- =
R"({"items":[{"uriFile":"file:///part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]}")";
+ =
R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-velox/src/test/resources/datasource/csv/student_option_schema.csv","length":"56","text":{"fieldDelimiter":",","maxBlockSize":"8192","header":"1"},"schema":{"names":["id","name","language"],"struct":{"types":[{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}}]}},"metadataColumns":[{}]}]})";
SerializedPlanParser parser(SerializedPlanParser::global_context);
- parser.addSplitInfo(splitBinaryFromJson(split));
- parser.parseJson(JsonPlanFor65234());
-}
+
parser.addSplitInfo(test::pb_util::JsonStringToBinary<substrait::ReadRel::LocalFiles>(split));
+ auto local_executor = parser.createExecutor<true>(
+ {reinterpret_cast<const char *>(gresource_embedded_readcsv_jsonData),
gresource_embedded_readcsv_jsonSize});
+
+ while (local_executor->hasNext())
+ {
+ const Block & x = *local_executor->nextColumnar();
+ EXPECT_EQ(4, x.rows());
+ }
+
+ StorageInMemoryMetadata metadata;
+ metadata.setColumns(ColumnsDescription::fromNamesAndTypes({{"name",
STRING()}, {"value", UINT()}}));
+ StorageMetadataPtr metadata_ptr =
std::make_shared<StorageInMemoryMetadata>(metadata);
+
+ auto sink = createFilelinkSink(
+ metadata_ptr,
+ "test_table",
+ "/tmp/test_table.parquet",
+ CompressionMethod::None,
+ {},
+ "Parquet",
+ SerializedPlanParser::global_context,
+ 0);
+
+ sink->consume(testChunk());
+ sink->onFinish();
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/tests/gtest_parser.cpp
b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json
similarity index 84%
copy from cpp-ch/local-engine/tests/gtest_parser.cpp
copy to cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json
index cbe41c90c..1c37b68b7 100644
--- a/cpp-ch/local-engine/tests/gtest_parser.cpp
+++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json
@@ -1,39 +1,3 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <Parser/SerializedPlanParser.h>
-#include <google/protobuf/util/json_util.h>
-#include <gtest/gtest.h>
-
-using namespace local_engine;
-using namespace DB;
-
-std::string splitBinaryFromJson(const std::string & json)
-{
- std::string binary;
- substrait::ReadRel::LocalFiles local_files;
- auto s =
google::protobuf::util::JsonStringToMessage(absl::string_view(json),
&local_files);
- local_files.SerializeToString(&binary);
- return binary;
-}
-
-std::string JsonPlanFor65234()
-{
- // Plan for https://github.com/ClickHouse/ClickHouse/pull/65234
- return R"(
{
"extensions": [{
"extensionFunction": {
@@ -306,15 +270,4 @@ std::string JsonPlanFor65234()
}
}
}]
-}
-)";
-}
-
-TEST(SerializedPlanParser, PR65234)
-{
- const std::string split
- =
R"({"items":[{"uriFile":"file:///part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]}")";
- SerializedPlanParser parser(SerializedPlanParser::global_context);
- parser.addSplitInfo(splitBinaryFromJson(split));
- parser.parseJson(JsonPlanFor65234());
-}
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json
b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json
new file mode 100644
index 000000000..10f0ea3df
--- /dev/null
+++ b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json
@@ -0,0 +1,269 @@
+{
+ "advancedExtensions": {
+ "enhancement": {
+ "@type": "type.googleapis.com/substrait.Expression",
+ "literal": {
+ "map": {
+ "keyValues": [
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level"
+ },
+ "value": {
+ "string": "test"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort"
+ },
+ "value": {
+ "string": "5368709120"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.endpoint"
+ },
+ "value": {
+ "string": "localhost:9000"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.gluten.sql.columnar.backend.velox.IOThreads"
+ },
+ "value": {
+ "string": "0"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout"
+ },
+ "value": {
+ "string": "180000"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations"
+ },
+ "value": {
+ "string": "false"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.gluten.sql.columnar.backend.ch.worker.id"
+ },
+ "value": {
+ "string": "1"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.memory.offHeap.enabled"
+ },
+ "value": {
+ "string": "true"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.iam.role.session.name"
+ },
+ "value": {
+ "string": ""
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout"
+ },
+ "value": {
+ "string": "180000"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.gluten.sql.columnar.shuffle.codec"
+ },
+ "value": {
+ "string": ""
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles"
+ },
+ "value": {
+ "string": "true"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.gluten.memory.offHeap.size.in.bytes"
+ },
+ "value": {
+ "string": "10737418240"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.gluten.sql.columnar.shuffle.codecBackend"
+ },
+ "value": {
+ "string": ""
+ }
+ },
+ {
+ "key": {
+ "string": "spark.sql.orc.compression.codec"
+ },
+ "value": {
+ "string": "snappy"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by"
+ },
+ "value": {
+ "string": "5368709120"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.input.write.timeout"
+ },
+ "value": {
+ "string": "180000"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.secret.key"
+ },
+ "value": {
+ "string": ""
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.access.key"
+ },
+ "value": {
+ "string": ""
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity"
+ },
+ "value": {
+ "string": "INFO"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.path.style.access"
+ },
+ "value": {
+ "string": "true"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.timezone"
+ },
+ "value": {
+ "string": "Asia/Shanghai"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.input.read.timeout"
+ },
+ "value": {
+ "string": "180000"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.use.instance.credentials"
+ },
+ "value": {
+ "string": "false"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method"
+ },
+ "value": {
+ "string": "snappy"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.iam.role"
+ },
+ "value": {
+ "string": ""
+ }
+ },
+ {
+ "key": {
+ "string": "spark.gluten.memory.task.offHeap.size.in.bytes"
+ },
+ "value": {
+ "string": "10737418240"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.input.connect.timeout"
+ },
+ "value": {
+ "string": "180000"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.dfs.client.log.severity"
+ },
+ "value": {
+ "string": "INFO"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver"
+ },
+ "value": {
+ "string": "2"
+ }
+ },
+ {
+ "key": {
+ "string":
"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout"
+ },
+ "value": {
+ "string": "180000"
+ }
+ },
+ {
+ "key": {
+ "string": "spark.hadoop.fs.s3a.connection.ssl.enabled"
+ },
+ "value": {
+ "string": "false"
+ }
+ }
+ ]
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json
b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json
new file mode 100644
index 000000000..f9518d390
--- /dev/null
+++ b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json
@@ -0,0 +1,77 @@
+{
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "read": {
+ "common": {
+ "direct": {}
+ },
+ "baseSchema": {
+ "names": [
+ "id",
+ "name",
+ "language"
+ ],
+ "struct": {
+ "types": [
+ {
+ "string": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "string": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "string": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ]
+ },
+ "columnTypes": [
+ "NORMAL_COL",
+ "NORMAL_COL",
+ "NORMAL_COL"
+ ]
+ },
+ "advancedExtension": {
+ "optimization": {
+ "@type": "type.googleapis.com/google.protobuf.StringValue",
+ "value": "isMergeTree=0\n"
+ }
+ }
+ }
+ },
+ "names": [
+ "id#20",
+ "name#21",
+ "language#22"
+ ],
+ "outputSchema": {
+ "types": [
+ {
+ "string": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "string": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "string": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 9a37c4a40..3ca5e0313 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -430,7 +430,9 @@ trait SparkPlanExecApi {
*
* @return
*/
- def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]]
+ def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = {
+ SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List()
+ }
def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]]
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala
b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala
index 77d5d55f6..a6ec7cb21 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala
@@ -24,37 +24,34 @@ import io.substrait.proto.{NamedStruct, Plan}
object SubstraitPlanPrinterUtil extends Logging {
- /** Transform Substrait Plan to json format. */
- def substraitPlanToJson(substraintPlan: Plan): String = {
+ private def typeRegistry(
+ d: com.google.protobuf.Descriptors.Descriptor):
com.google.protobuf.TypeRegistry = {
val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes
- val registry = com.google.protobuf.TypeRegistry
+ com.google.protobuf.TypeRegistry
.newBuilder()
- .add(substraintPlan.getDescriptorForType())
+ .add(d)
.add(defaultRegistry)
.build()
- JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan)
+ }
+ private def MessageToJson(message: com.google.protobuf.Message): String = {
+ val registry = typeRegistry(message.getDescriptorForType)
+ JsonFormat.printer.usingTypeRegistry(registry).print(message)
}
- def substraitNamedStructToJson(substraintPlan: NamedStruct): String = {
- val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes
- val registry = com.google.protobuf.TypeRegistry
- .newBuilder()
- .add(substraintPlan.getDescriptorForType())
- .add(defaultRegistry)
- .build()
- JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan)
+ /** Transform Substrait Plan to json format. */
+ def substraitPlanToJson(substraitPlan: Plan): String = {
+ MessageToJson(substraitPlan)
+ }
+
+ def substraitNamedStructToJson(namedStruct: NamedStruct): String = {
+ MessageToJson(namedStruct)
}
/** Transform substrait plan json string to PlanNode */
def jsonToSubstraitPlan(planJson: String): Plan = {
try {
val builder = Plan.newBuilder()
- val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes
- val registry = com.google.protobuf.TypeRegistry
- .newBuilder()
- .add(builder.getDescriptorForType)
- .add(defaultRegistry)
- .build()
+ val registry = typeRegistry(builder.getDescriptorForType)
JsonFormat.parser().usingTypeRegistry(registry).merge(planJson, builder)
builder.build()
} catch {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]