This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d9c90887cb9 [SPARK-40780][CONNECT] Add WHERE to Connect proto and DSL
d9c90887cb9 is described below
commit d9c90887cb9ef32d54b3e0edcfffb43ba3d70fa6
Author: Rui Wang <[email protected]>
AuthorDate: Thu Oct 13 21:22:21 2022 +0800
[SPARK-40780][CONNECT] Add WHERE to Connect proto and DSL
### What changes were proposed in this pull request?
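Add a `where` operator to the Connect DSL that builds a `proto.Filter` relation from an input plan and a condition. To make filter conditions expressible, this also adds a `<` comparison on DSL expressions (encoded as an unresolved function named `<`) and an implicit `Int` to literal-expression conversion. Separately, `SimpleAnalyzer` now uses `FunctionRegistry.builtin` instead of `EmptyFunctionRegistry`, so that unresolved functions such as `<` can be resolved during analysis. No `.proto` definitions change; `proto.Filter` already exists.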
### Why are the changes needed?
Improve Connect proto testing coverage.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
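Unit test: added "Basic filter" to SparkConnectProtoSuite, which compares the analyzed Connect plan with the equivalent Catalyst plan.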
Closes #38232 from amaliujia/add_filter_to_dsl.
Authored-by: Rui Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/connect/dsl/package.scala | 22 ++++++++++++++++++++++
.../connect/planner/SparkConnectProtoSuite.scala | 11 +++++++++++
.../spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
3 files changed, 34 insertions(+), 1 deletion(-)
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 80d6e77c9fc..0db8ab96610 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connect
import scala.collection.JavaConverters._
+import scala.language.implicitConversions
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
@@ -44,7 +45,20 @@ package object dsl {
implicit class DslExpression(val expr: proto.Expression) {
def as(alias: String): proto.Expression =
proto.Expression.newBuilder().setAlias(
proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build()
+
+ def < (other: proto.Expression): proto.Expression =
+ proto.Expression.newBuilder().setUnresolvedFunction(
+ proto.Expression.UnresolvedFunction.newBuilder()
+ .addParts("<")
+ .addArguments(expr)
+ .addArguments(other)
+ ).build()
}
+
+ implicit def intToLiteral(i: Int): proto.Expression =
+ proto.Expression.newBuilder().setLiteral(
+ proto.Expression.Literal.newBuilder().setI32(i)
+ ).build()
}
object plans { // scalastyle:ignore
@@ -58,6 +72,14 @@ package object dsl {
).build()
}
+ def where(condition: proto.Expression): proto.Relation = {
+ proto.Relation.newBuilder()
+ .setFilter(
+ proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
+ ).build()
+ }
+
+
def join(
otherPlan: proto.Relation,
joinType: JoinType = JoinType.JOIN_TYPE_INNER,
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 510b54cd250..351cc70852a 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -51,6 +51,17 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
}
+ test("Basic filter") {
+ val connectPlan = {
+ import org.apache.spark.sql.connect.dsl.expressions._
+ import org.apache.spark.sql.connect.dsl.plans._
+ transform(connectTestRelation.where("id".protoAttr < 0))
+ }
+
+ val sparkPlan = sparkTestRelation.where($"id" < 0).analyze
+ comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+ }
+
test("Basic joins with different join types") {
val connectPlan = {
import org.apache.spark.sql.connect.dsl.plans._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62d930dcd20..ae65902e8a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -70,7 +70,7 @@ object SimpleAnalyzer extends Analyzer(
FakeV2SessionCatalog,
new SessionCatalog(
new InMemoryCatalog,
- EmptyFunctionRegistry,
+ FunctionRegistry.builtin,
EmptyTableFunctionRegistry) {
override def createDatabase(dbDefinition: CatalogDatabase,
ignoreIfExists: Boolean): Unit = {}
})) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]