This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9c90887cb9 [SPARK-40780][CONNECT] Add WHERE to Connect proto and DSL
d9c90887cb9 is described below

commit d9c90887cb9ef32d54b3e0edcfffb43ba3d70fa6
Author: Rui Wang <[email protected]>
AuthorDate: Thu Oct 13 21:22:21 2022 +0800

    [SPARK-40780][CONNECT] Add WHERE to Connect proto and DSL
    
    ### What changes were proposed in this pull request?
    
    Add WHERE to Connect proto and DSL.
    
    ### Why are the changes needed?
    
    Improve Connect proto testing coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UT
    
    Closes #38232 from amaliujia/add_filter_to_dsl.
    
    Authored-by: Rui Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/sql/connect/dsl/package.scala | 22 ++++++++++++++++++++++
 .../connect/planner/SparkConnectProtoSuite.scala   | 11 +++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  2 +-
 3 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 80d6e77c9fc..0db8ab96610 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.connect
 
 import scala.collection.JavaConverters._
+import scala.language.implicitConversions
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
@@ -44,7 +45,20 @@ package object dsl {
     implicit class DslExpression(val expr: proto.Expression) {
       def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias(
         proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build()
+
+      def < (other: proto.Expression): proto.Expression =
+        proto.Expression.newBuilder().setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction.newBuilder()
+            .addParts("<")
+            .addArguments(expr)
+            .addArguments(other)
+        ).build()
     }
+
+    implicit def intToLiteral(i: Int): proto.Expression =
+      proto.Expression.newBuilder().setLiteral(
+        proto.Expression.Literal.newBuilder().setI32(i)
+      ).build()
   }
 
   object plans { // scalastyle:ignore
@@ -58,6 +72,14 @@ package object dsl {
         ).build()
       }
 
+      def where(condition: proto.Expression): proto.Relation = {
+        proto.Relation.newBuilder()
+          .setFilter(
+            proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
+        ).build()
+      }
+
+
       def join(
           otherPlan: proto.Relation,
           joinType: JoinType = JoinType.JOIN_TYPE_INNER,
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 510b54cd250..351cc70852a 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -51,6 +51,17 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
 
+  test("Basic filter") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.where("id".protoAttr < 0))
+    }
+
+    val sparkPlan = sparkTestRelation.where($"id" < 0).analyze
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+  }
+
   test("Basic joins with different join types") {
     val connectPlan = {
       import org.apache.spark.sql.connect.dsl.plans._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62d930dcd20..ae65902e8a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -70,7 +70,7 @@ object SimpleAnalyzer extends Analyzer(
     FakeV2SessionCatalog,
     new SessionCatalog(
       new InMemoryCatalog,
-      EmptyFunctionRegistry,
+      FunctionRegistry.builtin,
       EmptyTableFunctionRegistry) {
       override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
     })) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to