amaliujia commented on code in PR #38406:
URL: https://github.com/apache/spark/pull/38406#discussion_r1006297995


##########
connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala:
##########
@@ -30,181 +36,140 @@ import 
org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  */
 class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
-  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, 
$"name".string))
+  lazy val connectTestRelation =
+    createLocalRelationProto(
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", 
StringType)()))
 
-  lazy val connectTestRelation2 = createLocalRelationProto(
-    Seq($"key".int, $"value".int, $"name".string))
+  lazy val connectTestRelation2 =
+    createLocalRelationProto(
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", 
StringType)()))
 
-  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, 
$"name".string)
+  lazy val sparkTestRelation: LocalRelation =
+    LocalRelation(
+      AttributeReference("id", IntegerType)(),
+      AttributeReference("name", StringType)())
 
   lazy val sparkTestRelation2: LocalRelation =
-    LocalRelation($"key".int, $"value".int, $"name".string)
+    LocalRelation(
+      AttributeReference("id", IntegerType)(),
+      AttributeReference("name", StringType)())
 
   test("Basic select") {
-    val connectPlan = {
-      // TODO: Scala only allows one implicit per scope so we keep proto 
implicit imports in
-      // this scope. Need to find a better way to make two implicits work in 
the same scope.
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.select("id".protoAttr))
-    }
-    val sparkPlan = sparkTestRelation.select($"id")
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.select("id".protoAttr)
+    val sparkPlan = sparkTestRelation.select("id")
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("UnresolvedFunction resolution.") {
-    {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      assertThrows[IllegalArgumentException] {
-        transform(connectTestRelation.select(callFunction("default.hex", 
Seq("id".protoAttr))))
-      }
+    assertThrows[IllegalArgumentException] {
+      transform(connectTestRelation.select(callFunction("default.hex", 
Seq("id".protoAttr))))
     }
 
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(
-        connectTestRelation.select(callFunction(Seq("default", "hex"), 
Seq("id".protoAttr))))
-    }
+    val connectPlan =
+      connectTestRelation.select(callFunction(Seq("default", "hex"), 
Seq("id".protoAttr)))
 
     assertThrows[UnsupportedOperationException] {
-      connectPlan.analyze
+      analyzePlan(transform(connectPlan))
     }
 
-    val validPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.select(callFunction(Seq("hex"), 
Seq("id".protoAttr))))
-    }
-    assert(validPlan.analyze != null)
+    val validPlan = connectTestRelation.select(callFunction(Seq("hex"), 
Seq("id".protoAttr)))
+    assert(analyzePlan(transform(validPlan)) != null)
   }
 
   test("Basic filter") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.where("id".protoAttr < 0))
-    }
-
-    val sparkPlan = sparkTestRelation.where($"id" < 0).analyze
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.where("id".protoAttr < 0)
+    val sparkPlan = sparkTestRelation.where(Column("id") < 0)
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("Basic joins with different join types") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.join(connectTestRelation2))
-    }
+    val connectPlan = connectTestRelation.join(connectTestRelation2)
     val sparkPlan = sparkTestRelation.join(sparkTestRelation2)
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    comparePlans(connectPlan, sparkPlan)
+
+    val connectPlan2 = connectTestRelation.join(connectTestRelation2)
+    val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2)
+    comparePlans(connectPlan2, sparkPlan2)
 
-    val connectPlan2 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.join(connectTestRelation2, condition = 
None))
-    }
-    val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = 
None)
-    comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
     for ((t, y) <- Seq(
         (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
         (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
         (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
         (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
         (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
         (JoinType.JOIN_TYPE_INNER, Inner))) {
-      val connectPlan3 = {
-        import org.apache.spark.sql.connect.dsl.plans._
-        transform(connectTestRelation.join(connectTestRelation2, t))
-      }
-      val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y)
-      comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
-    }
 
-    val connectPlan4 = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(
-        connectTestRelation.join(connectTestRelation2, 
JoinType.JOIN_TYPE_INNER, Seq("name")))
+      val connectPlan3 = connectTestRelation.join(connectTestRelation2, t, 
Seq("id"))
+      val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, Seq("id"), 
y.toString)
+      comparePlans(connectPlan3, sparkPlan3)
     }
-    val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, 
UsingJoin(Inner, Seq("name")))
-    comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false)
+
+    val connectPlan4 =
+      connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, 
Seq("name"))
+    val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, Seq("name"), 
Inner.toString)
+    comparePlans(connectPlan4, sparkPlan4)
   }
 
   test("Test sample") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.sample(0, 0.2, false, 1))
-    }
-    val sparkPlan = sparkTestRelation.sample(0, 0.2, false, 1)
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.sample(0, 0.2, false, 1)
+    val sparkPlan = sparkTestRelation.sample(false, 0.2 - 0, 1)
+    comparePlans(connectPlan, sparkPlan)
   }
 
   test("column alias") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.select("id".protoAttr.as("id2")))
-    }
-    val sparkPlan = sparkTestRelation.select($"id".as("id2"))
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+    val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
+    val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))
+    comparePlans(connectPlan, sparkPlan)
   }
 
-  test("Aggregate with more than 1 grouping expressions") {
-    val connectPlan = {
-      import org.apache.spark.sql.connect.dsl.expressions._
-      import org.apache.spark.sql.connect.dsl.plans._
-      transform(connectTestRelation.groupBy("id".protoAttr, 
"name".protoAttr)())
-    }
-    val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
-    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
-  }
+// TODO: improve aggregate API parity.
+//  test("Aggregate with more than 1 grouping expressions") {
+//    val connectPlan =
+//      connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)()
+//    val sparkPlan =
+//      sparkTestRelation.groupBy(Column("id"), 
Column("name")).agg(Map.empty[String, String])
+//    comparePlans(connectPlan, sparkPlan)

Review Comment:
   Currently, Connect is not compatible with the existing DataFrame API for grouping/aggregation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to