This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ef738205c0 [SPARK-39914][SQL] Add DS V2 Filter to V1 Filter conversion
2ef738205c0 is described below

commit 2ef738205c0d4598a577a248afc117ac0844f3ad
Author: huaxingao <[email protected]>
AuthorDate: Mon Aug 1 11:23:13 2022 -0700

    [SPARK-39914][SQL] Add DS V2 Filter to V1 Filter conversion
    
    ### What changes were proposed in this pull request?
    Add util methods to convert DS V2 Filter to V1 Filter.
    
    ### Why are the changes needed?
    Provide convenient methods to convert V2 to V1 Filters. These methods can 
be used by 
[`SupportsRuntimeFiltering`](https://github.com/apache/spark/pull/36918/files#diff-0d3268f351817ca948e75e7b6641e5cc67c4d773c3234920a7aa62faf11f6c8e)
 and later be used by `SupportsDelete`
    
    ### Does this PR introduce _any_ user-facing change?
    No. These are intended for internal use only.
    
    ### How was this patch tested?
    New tests.
    
    Closes #37332 from huaxingao/toV1.
    
    Authored-by: huaxingao <[email protected]>
    Signed-off-by: huaxingao <[email protected]>
---
 .../sql/internal/connector/PredicateUtils.scala    | 92 +++++++++++++++++++++-
 .../datasources/v2/V2PredicateSuite.scala          | 85 ++++++++++++++++++++
 2 files changed, 174 insertions(+), 3 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
index ace6b30d4cc..263edd82197 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -19,14 +19,25 @@ package org.apache.spark.sql.internal.connector
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.connector.expressions.{LiteralValue, 
NamedReference}
-import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => 
V2Not, Or => V2Or, Predicate}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, 
EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, 
IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, 
StringStartsWith}
+import org.apache.spark.sql.types.StringType
 
 private[sql] object PredicateUtils {
 
   def toV1(predicate: Predicate): Option[Filter] = {
+
+    def isValidBinaryPredicate(): Boolean = {
+      if (predicate.children().length == 2 &&
+        predicate.children()(0).isInstanceOf[NamedReference] &&
+        predicate.children()(1).isInstanceOf[LiteralValue[_]]) {
+        true
+      } else {
+        false
+      }
+    }
+
     predicate.name() match {
-      // TODO: add conversion for other V2 Predicate
       case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
         val attribute = predicate.children()(0).toString
         val values = predicate.children().drop(1)
@@ -43,6 +54,81 @@ private[sql] object PredicateUtils {
           Some(In(attribute, Array.empty[Any]))
         }
 
+      case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate =>
+        val attribute = predicate.children()(0).toString
+        val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+        val v1Value = CatalystTypeConverters.convertToScala(value.value, 
value.dataType)
+        val v1Filter = predicate.name() match {
+          case "=" => EqualTo(attribute, v1Value)
+          case "<=>" => EqualNullSafe(attribute, v1Value)
+          case ">" => GreaterThan(attribute, v1Value)
+          case ">=" => GreaterThanOrEqual(attribute, v1Value)
+          case "<" => LessThan(attribute, v1Value)
+          case "<=" => LessThanOrEqual(attribute, v1Value)
+        }
+        Some(v1Filter)
+
+      case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 &&
+          predicate.children()(0).isInstanceOf[NamedReference] =>
+        val attribute = predicate.children()(0).toString
+        val v1Filter = predicate.name() match {
+          case "IS_NULL" => IsNull(attribute)
+          case "IS_NOT_NULL" => IsNotNull(attribute)
+        }
+        Some(v1Filter)
+
+      case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate 
=>
+        val attribute = predicate.children()(0).toString
+        val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+        if (!value.dataType.sameType(StringType)) return None
+        val v1Value = value.value.toString
+        val v1Filter = predicate.name() match {
+          case "STARTS_WITH" =>
+            StringStartsWith(attribute, v1Value)
+          case "ENDS_WITH" =>
+            StringEndsWith(attribute, v1Value)
+          case "CONTAINS" =>
+            StringContains(attribute, v1Value)
+        }
+        Some(v1Filter)
+
+      case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty =>
+        val v1Filter = predicate.name() match {
+          case "ALWAYS_TRUE" => AlwaysTrue()
+          case "ALWAYS_FALSE" => AlwaysFalse()
+        }
+        Some(v1Filter)
+
+      case "AND" =>
+        val and = predicate.asInstanceOf[V2And]
+        val left = toV1(and.left())
+        val right = toV1(and.right())
+        if (left.nonEmpty && right.nonEmpty) {
+          Some(And(left.get, right.get))
+        } else {
+          None
+        }
+
+      case "OR" =>
+        val or = predicate.asInstanceOf[V2Or]
+        val left = toV1(or.left())
+        val right = toV1(or.right())
+        if (left.nonEmpty && right.nonEmpty) {
+          Some(Or(left.get, right.get))
+        } else if (left.nonEmpty) {
+          left
+        } else {
+          right
+        }
+
+      case "NOT" =>
+        val child = toV1(predicate.asInstanceOf[V2Not].child())
+        if (child.nonEmpty) {
+          Some(Not(child.get))
+        } else {
+          None
+        }
+
       case _ => None
     }
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
index 2df8b8e56c4..de556c50f5d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, 
Literal, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter._
 import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref
+import org.apache.spark.sql.internal.connector.PredicateUtils
 import org.apache.spark.sql.sources.{AlwaysFalse => V1AlwaysFalse, AlwaysTrue 
=> V1AlwaysTrue, And => V1And, EqualNullSafe, EqualTo, GreaterThan, 
GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not => 
V1Not, Or => V1Or, StringContains, StringEndsWith, StringStartsWith}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -34,6 +35,9 @@ class V2PredicateSuite extends SparkFunSuite {
     assert(predicate1.describe.equals("a.B = 1"))
     val v1Filter1 = EqualTo(ref("a", "B").describe(), 1)
     assert(v1Filter1.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+    assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
 
     val predicate2 =
       new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, 
IntegerType)))
@@ -41,6 +45,9 @@ class V2PredicateSuite extends SparkFunSuite {
     assert(predicate2.describe.equals("a.`b.c` = 1"))
     val v1Filter2 = EqualTo(ref("a", "b.c").describe(), 1)
     assert(v1Filter2.toV2 == predicate2)
+    assert(PredicateUtils.toV1(predicate2).get == v1Filter2)
+    assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+    assert(PredicateUtils.toV1(predicate2).get.toV2 == predicate2)
 
     val predicate3 =
       new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, 
IntegerType)))
@@ -48,6 +55,9 @@ class V2PredicateSuite extends SparkFunSuite {
     assert(predicate3.describe.equals("```a``.b`.c = 1"))
     val v1Filter3 = EqualTo(ref("`a`.b", "c").describe(), 1)
     assert(v1Filter3.toV2 == predicate3)
+    assert(PredicateUtils.toV1(predicate3).get == v1Filter3)
+    assert(PredicateUtils.toV1(v1Filter3.toV2).get == v1Filter3)
+    assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
   }
 
   test("AlwaysTrue") {
@@ -59,6 +69,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = V1AlwaysTrue
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("AlwaysFalse") {
@@ -70,6 +83,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = V1AlwaysFalse
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("EqualTo") {
@@ -81,6 +97,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = EqualTo("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("EqualNullSafe") {
@@ -92,6 +111,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = EqualNullSafe("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("LessThan") {
@@ -103,6 +125,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = LessThan("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("LessThanOrEqual") {
@@ -114,6 +139,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = LessThanOrEqual("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("GreatThan") {
@@ -125,6 +153,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = GreaterThan("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("GreatThanOrEqual") {
@@ -136,6 +167,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = GreaterThanOrEqual("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("In") {
@@ -161,9 +195,15 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter1 = In("a", Array(1, 2, 3, 4))
     assert(v1Filter1.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+    assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
 
     val v1Filter2 = In("a", values.map(_.value()))
     assert(v1Filter2.toV2 == predicate3)
+    assert(PredicateUtils.toV1(predicate3).get == v1Filter2)
+    assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+    assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
   }
 
   test("IsNull") {
@@ -175,6 +215,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = IsNull("a")
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("IsNotNull") {
@@ -186,6 +229,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = IsNotNull("a")
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("Not") {
@@ -199,6 +245,14 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = V1Not(LessThan("a", 1))
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val predicate3 = new Not(
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+        LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == None)
   }
 
   test("And") {
@@ -214,6 +268,15 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1))
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val predicate3 = new And(
+      new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, 
IntegerType))),
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+        LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == None)
   }
 
   test("Or") {
@@ -229,6 +292,19 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1))
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val left = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, 
IntegerType)))
+    val predicate3 = new Or(left,
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == PredicateUtils.toV1(left))
+
+    val predicate4 = new Or(
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))),
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate4) == None)
   }
 
   test("StringStartsWith") {
@@ -243,6 +319,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = StringStartsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("StringEndsWith") {
@@ -257,6 +336,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = StringEndsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("StringContains") {
@@ -271,6 +353,9 @@ class V2PredicateSuite extends SparkFunSuite {
 
     val v1Filter = StringContains("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to