This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 671bac5e5d [GLUTEN-8304][CORE] Add an optimization rule to collapse
nested get_json_object functions (#8305)
671bac5e5d is described below
commit 671bac5e5d4e45d40ec178b52e4658f6dc895927
Author: kevinyhzou <[email protected]>
AuthorDate: Tue Jan 7 10:10:23 2025 +0800
[GLUTEN-8304][CORE] Add an optimization rule to collapse nested
get_json_object functions (#8305)
---
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 +
.../execution/GlutenFunctionValidateSuite.scala | 85 +++++++++++++++-
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 1 +
.../CollapseGetJsonObjectExpressionRule.scala | 111 +++++++++++++++++++++
.../org/apache/gluten/config/GlutenConfig.scala | 10 ++
5 files changed, 207 insertions(+), 1 deletion(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index e6ee3f79ee..016664ed97 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -61,6 +61,7 @@ object CHRuleApi {
(spark, parserInterface) => new GlutenClickhouseSqlParser(spark,
parserInterface))
injector.injectResolutionRule(spark => new
RewriteToDateExpresstionRule(spark))
injector.injectResolutionRule(spark => new
RewriteDateTimestampComparisonRule(spark))
+ injector.injectResolutionRule(spark => new
CollapseGetJsonObjectExpressionRule(spark))
injector.injectOptimizerRule(spark => new
CommonSubexpressionEliminateRule(spark))
injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
injector.injectOptimizerRule(spark =>
CHAggregateFunctionRewriteRule(spark))
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index e1287c8b6d..d900bc000c 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -21,7 +21,9 @@ import org.apache.gluten.utils.UTSystemParameters
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, TestUtils}
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetJsonObject,
Literal}
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding,
NullPropagation}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project}
import
org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -90,7 +92,9 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
Row(1.011, 5, "{\"a\":\"b\", \"x\":{\"i\":1}}"),
Row(1.011, 5, "{\"a\":\"b\", \"x\":{\"i\":2}}"),
Row(1.011, 5, "{\"a\":1, \"x\":{\"i\":2}}"),
- Row(1.0, 5, "{\"a\":\"{\\\"x\\\":5}\"}")
+ Row(1.0, 5, "{\"a\":\"{\\\"x\\\":5}\"}"),
+ Row(1.0, 6, "{\"a\":{\"y\": 5, \"z\": {\"m\":1, \"n\": {\"p\":
\"k\"}}}"),
+ Row(1.0, 7, "{\"a\":[{\"y\": 5}, {\"z\":[{\"m\":1,
\"n\":{\"p\":\"k\"}}]}]}")
))
val dfParquet = spark.createDataFrame(data, schema)
dfParquet
@@ -268,6 +272,85 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
}
}
+ test("GLUTEN-8304: Optimize nested get_json_object") {
+ def checkExpression(expr: Expression, path: String): Boolean = {
+ expr match {
+ case g: GetJsonObject
+ if g.path.isInstanceOf[Literal] &&
g.path.dataType.isInstanceOf[StringType] =>
+ g.path.asInstanceOf[Literal].value.toString.equals(path) ||
g.children.exists(
+ c => checkExpression(c, path))
+ case _ =>
+ if (expr.children.isEmpty) {
+ false
+ } else {
+ expr.children.exists(c => checkExpression(c, path))
+ }
+ }
+ }
+ def checkPlan(plan: LogicalPlan, path: String): Boolean = plan match {
+ case p: Project =>
+ p.projectList.exists(x => checkExpression(x, path)) ||
checkPlan(p.child, path)
+ case f: Filter =>
+ checkExpression(f.condition, path) || checkPlan(f.child, path)
+ case _ =>
+ if (plan.children.isEmpty) {
+ false
+ } else {
+ plan.children.exists(c => checkPlan(c, path))
+ }
+ }
+ def checkGetJsonObjectPath(df: DataFrame, path: String): Boolean = {
+ checkPlan(df.queryExecution.analyzed, path)
+ }
+ withSQLConf(("spark.gluten.sql.collapseGetJsonObject.enabled", "true")) {
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(string_field1, '$.a'), '$.y')
" +
+ " from json_test where int_field1 = 6") {
+ x => assert(checkGetJsonObjectPath(x, "$.a.y"))
+ }
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(string_field1, '$[a]'),
'$[y]') " +
+ " from json_test where int_field1 = 6") {
+ x => assert(checkGetJsonObjectPath(x, "$[a][y]"))
+ }
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(get_json_object(string_field1,
" +
+ "'$.a'), '$.y'), '$.z') from json_test where int_field1 = 6") {
+ x => assert(checkGetJsonObjectPath(x, "$.a.y.z"))
+ }
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(get_json_object(string_field1,
'$.a')," +
+ " string_field1), '$.z') from json_test where int_field1 = 6",
+ noFallBack = false
+ )(x => assert(checkGetJsonObjectPath(x, "$.a") &&
checkGetJsonObjectPath(x, "$.z")))
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(get_json_object(string_field1,
" +
+ " string_field1), '$.a'), '$.z') from json_test where int_field1 =
6",
+ noFallBack = false
+ )(x => assert(checkGetJsonObjectPath(x, "$.a.z")))
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(get_json_object(" +
+ " substring(string_field1, 10), '$.a'), '$.z'), string_field1) " +
+ " from json_test where int_field1 = 6",
+ noFallBack = false
+ )(x => assert(checkGetJsonObjectPath(x, "$.a.z")))
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(string_field1, '$.a[0]'),
'$.y') " +
+ " from json_test where int_field1 = 7") {
+ x => assert(checkGetJsonObjectPath(x, "$.a[0].y"))
+ }
+ runQueryAndCompare(
+ "select get_json_object(get_json_object(get_json_object(string_field1,
" +
+ " '$.a[1]'), '$.z[1]'), '$.n') from json_test where int_field1 = 7")
{
+ x => assert(checkGetJsonObjectPath(x, "$.a[1].z[1].n"))
+ }
+ runQueryAndCompare(
+ "select * from json_test where " +
+ " get_json_object(get_json_object(get_json_object(string_field1,
'$.a'), " +
+ "'$.y'), '$.z') != null")(x => assert(checkGetJsonObjectPath(x,
"$.a.y.z")))
+ }
+ }
+
test("Test get_json_object 10") {
runQueryAndCompare("SELECT get_json_object(string_field1, '$.12345') from
json_test") { _ => }
runQueryAndCompare("SELECT get_json_object(string_field1, '$.123.abc')
from json_test") { _ => }
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 72d769c999..d36cb6f553 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -58,6 +58,7 @@ object VeloxRuleApi {
// Inject the regular Spark rules directly.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
+ injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseGetJsonObjectExpressionRule.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseGetJsonObjectExpressionRule.scala
new file mode 100644
index 0000000000..4c84f42149
--- /dev/null
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseGetJsonObjectExpressionRule.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.config.GlutenConfig
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule collapses nested `get_json_object` functions into a single call for
optimization, e.g.
+ * get_json_object(get_json_object(d, '$.a'), '$.b') => get_json_object(d,
'$.a.b'). Note that
+ * some cases cannot be handled by this rule:
+ * - get_json_object(get_json_object('{"a":"{\\\"x\\\":5}"}', '$.a'), '$.x'),
where the json string
+ * uses backslashes to escape quotes;
+ * - get_json_object(get_json_object('{"a.b": 0}', '$.a'), '$.b'), where the
json key contains a dot
+ * character (.) and therefore coincides with the collapsed json path;
+ */
+case class CollapseGetJsonObjectExpressionRule(spark: SparkSession) extends
Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (
+ plan.resolved
+ && GlutenConfig.get.enableCollapseNestedGetJsonObject
+ ) {
+ visitPlan(plan)
+ } else {
+ plan
+ }
+ }
+
+ private def visitPlan(plan: LogicalPlan): LogicalPlan = plan match {
+ case p: Project =>
+ var newProjectList = Seq.empty[NamedExpression]
+ p.projectList.foreach {
+ case a: Alias if a.child.isInstanceOf[GetJsonObject] =>
+ newProjectList :+=
optimizeNestedFunctions(a).asInstanceOf[NamedExpression]
+ case p =>
+ newProjectList :+= p
+ }
+ val newChild = visitPlan(p.child)
+ Project(newProjectList, newChild)
+ case f: Filter =>
+ val newCond = optimizeNestedFunctions(f.condition)
+ val newChild = visitPlan(f.child)
+ Filter(newCond, newChild)
+ case other =>
+ val children = other.children.map(visitPlan)
+ plan.withNewChildren(children)
+ }
+
+ private def optimizeNestedFunctions(
+ expr: Expression,
+ path: String = "",
+ isNested: Boolean = false): Expression = {
+
+ def getPathLiteral(path: Expression): Option[String] = path match {
+ case l: Literal =>
+ Option.apply(l.value.toString)
+ case _ =>
+ Option.empty
+ }
+
+ expr match {
+ case g: GetJsonObject =>
+ val gPath = getPathLiteral(g.path).orNull
+ var newPath = ""
+ if (gPath != null) {
+ newPath = gPath.replace("$", "") + path
+ }
+ val res = optimizeNestedFunctions(g.json, newPath, isNested = true)
+ if (gPath != null) {
+ res
+ } else {
+ var newChildren = Seq.empty[Expression]
+ newChildren :+= res
+ newChildren :+= g.path
+ val newExpr = g.withNewChildren(newChildren)
+ if (path.nonEmpty) {
+ GetJsonObject(newExpr, Literal.apply("$" + path))
+ } else {
+ newExpr
+ }
+ }
+ case _ =>
+ val newChildren = expr.children.map(x => optimizeNestedFunctions(x,
path))
+ val newExpr = expr.withNewChildren(newChildren)
+ if (isNested && path.nonEmpty) {
+ val pathExpr = Literal.apply("$" + path)
+ GetJsonObject(newExpr, pathExpr)
+ } else {
+ newExpr
+ }
+ }
+ }
+}
diff --git
a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index 57bb31dd04..f6ed032734 100644
--- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -105,6 +105,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableRewriteDateTimestampComparison: Boolean =
conf.getConf(ENABLE_REWRITE_DATE_TIMESTAMP_COMPARISON)
+ def enableCollapseNestedGetJsonObject: Boolean =
+ conf.getConf(ENABLE_COLLAPSE_GET_JSON_OBJECT)
+
def enableCHRewriteDateConversion: Boolean =
conf.getConf(ENABLE_CH_REWRITE_DATE_CONVERSION)
@@ -1966,6 +1969,13 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)
+ val ENABLE_COLLAPSE_GET_JSON_OBJECT =
+ buildConf("spark.gluten.sql.collapseGetJsonObject.enabled")
+ .internal()
+ .doc("Collapse nested get_json_object functions as one for
optimization.")
+ .booleanConf
+ .createWithDefault(false)
+
val ENABLE_CH_REWRITE_DATE_CONVERSION =
buildConf("spark.gluten.sql.columnar.backend.ch.rewrite.dateConversion")
.internal()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]