This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ec5d5471856 [SPARK-41349][CONNECT] Implement DataFrame.hint
ec5d5471856 is described below
commit ec5d547185645126dee87470835ea1d55936dcd0
Author: dengziming <[email protected]>
AuthorDate: Wed Dec 7 16:08:09 2022 +0800
[SPARK-41349][CONNECT] Implement DataFrame.hint
### What changes were proposed in this pull request?
1. Implement `DataFrame.hint` for the Scala API
2. Implement `DataFrame.hint` for the Python API (see the usage sketch below)
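For reference, the new method follows the existing `Dataset.hint` semantics. A minimal usage sketch, assuming a local `SparkSession` (the session setup and object wrapper are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession

object HintUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("hint-usage").getOrCreate()
    val df = spark.range(10).toDF("id")

    // The hint name and parameters are handed to the optimizer:
    // REPARTITION with a partition count, as exercised by the new tests.
    val hinted = df.hint("REPARTITION", 4)
    println(hinted.rdd.partitions.length) // 4

    // An unknown hint name is ignored rather than failing the query.
    println(df.hint("illegal").count()) // 10

    spark.stop()
  }
}
```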
### Why are the changes needed?
API coverage
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
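On the wire, a hint is encoded as a `Hint` relation that wraps the input plan and carries the hint name plus literal parameters; the server planner maps it to `UnresolvedHint`. A minimal sketch of the proto shape, mirroring the new planner test (the object wrapper and `println` are illustrative only):

```scala
import org.apache.spark.connect.proto

object HintProtoSketch {
  def main(args: Array[String]): Unit = {
    // Input relation: a simple SQL query, as in the new planner test.
    val input = proto.Relation
      .newBuilder()
      .setSql(proto.SQL.newBuilder().setQuery("select id from range(10)").build())

    // Parameters travel as Expression.Literal values, built the same
    // way toConnectProtoValue builds them for an Int.
    val param = proto.Expression.Literal.newBuilder().setInteger(10000).build()

    // The Hint relation: input plan + hint name + parameters.
    val hinted = proto.Relation
      .newBuilder()
      .setHint(
        proto.Hint
          .newBuilder()
          .setInput(input)
          .setName("REPARTITION")
          .addParameters(param))
      .build()

    println(hinted)
  }
}
```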
Closes #38899 from dengziming/SPARK-41349.
Authored-by: dengziming <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/connect/dsl/package.scala | 33 +++--
.../planner/LiteralValueProtoConverter.scala | 157 +++++++++++++++++++++
.../sql/connect/planner/SparkConnectPlanner.scala | 121 ++--------------
.../planner/LiteralValueProtoConverterSuite.scala | 32 +++++
.../connect/planner/SparkConnectPlannerSuite.scala | 46 ++++++
.../connect/planner/SparkConnectProtoSuite.scala | 4 +
6 files changed, 270 insertions(+), 123 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 8b1d69e03db..ec2d0cad95b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -26,6 +26,7 @@ import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.connect.proto.SetOperation.SetOpType
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
/**
* A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -241,16 +242,6 @@ package object dsl {
implicit class DslNAFunctions(val logicalPlan: Relation) {
- private def convertValue(value: Any) = {
- value match {
- case b: Boolean => Expression.Literal.newBuilder().setBoolean(b).build()
- case l: Long => Expression.Literal.newBuilder().setLong(l).build()
- case d: Double => Expression.Literal.newBuilder().setDouble(d).build()
- case s: String => Expression.Literal.newBuilder().setString(s).build()
- case o => throw new Exception(s"Unsupported value type: $o")
- }
- }
-
def fillValue(value: Any): Relation = {
Relation
.newBuilder()
@@ -258,7 +249,7 @@ package object dsl {
proto.NAFill
.newBuilder()
.setInput(logicalPlan)
- .addAllValues(Seq(convertValue(value)).asJava)
+ .addAllValues(Seq(toConnectProtoValue(value)).asJava)
.build())
.build()
}
@@ -271,13 +262,13 @@ package object dsl {
.newBuilder()
.setInput(logicalPlan)
.addAllCols(cols.toSeq.asJava)
- .addAllValues(Seq(convertValue(value)).asJava)
+ .addAllValues(Seq(toConnectProtoValue(value)).asJava)
.build())
.build()
}
def fillValueMap(valueMap: Map[String, Any]): Relation = {
- val (cols, values) = valueMap.mapValues(convertValue).toSeq.unzip
+ val (cols, values) = valueMap.mapValues(toConnectProtoValue).toSeq.unzip
Relation
.newBuilder()
.setFillNa(
@@ -338,8 +329,8 @@ package object dsl {
replace.addReplacements(
proto.NAReplace.Replacement
.newBuilder()
- .setOldValue(convertValue(oldValue))
- .setNewValue(convertValue(newValue)))
+ .setOldValue(toConnectProtoValue(oldValue))
+ .setNewValue(toConnectProtoValue(newValue)))
}
Relation
@@ -694,6 +685,18 @@ package object dsl {
.build()
}
+ def hint(name: String, parameters: Any*): Relation = {
+ Relation
+ .newBuilder()
+ .setHint(
+ Hint
+ .newBuilder()
+ .setInput(logicalPlan)
+ .setName(name)
+ .addAllParameters(parameters.map(toConnectProtoValue).asJava))
+ .build()
+ }
+
private def createSetOperation(
left: Relation,
right: Relation,
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
new file mode 100644
index 00000000000..5a54ad9ac64
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateStruct}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+object LiteralValueProtoConverter {
+
+ /**
+ * Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
+ *
+ * @return
+ * Expression
+ */
+ def toCatalystExpression(lit: proto.Expression.Literal): expressions.Expression = {
+ lit.getLiteralTypeCase match {
+ case proto.Expression.Literal.LiteralTypeCase.NULL =>
+ expressions.Literal(null, NullType)
+
+ case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+ expressions.Literal(lit.getBinary.toByteArray, BinaryType)
+
+ case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+ expressions.Literal(lit.getBoolean, BooleanType)
+
+ case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+ expressions.Literal(lit.getByte.toByte, ByteType)
+
+ case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+ expressions.Literal(lit.getShort.toShort, ShortType)
+
+ case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+ expressions.Literal(lit.getInteger, IntegerType)
+
+ case proto.Expression.Literal.LiteralTypeCase.LONG =>
+ expressions.Literal(lit.getLong, LongType)
+
+ case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+ expressions.Literal(lit.getFloat, FloatType)
+
+ case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+ expressions.Literal(lit.getDouble, DoubleType)
+
+ case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+ val decimal = Decimal.apply(lit.getDecimal.getValue)
+ var precision = decimal.precision
+ if (lit.getDecimal.hasPrecision) {
+ precision = math.max(precision, lit.getDecimal.getPrecision)
+ }
+ var scale = decimal.scale
+ if (lit.getDecimal.hasScale) {
+ scale = math.max(scale, lit.getDecimal.getScale)
+ }
+ expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale))
+
+ case proto.Expression.Literal.LiteralTypeCase.STRING =>
+ expressions.Literal(UTF8String.fromString(lit.getString), StringType)
+
+ case proto.Expression.Literal.LiteralTypeCase.DATE =>
+ expressions.Literal(lit.getDate, DateType)
+
+ case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+ expressions.Literal(lit.getTimestamp, TimestampType)
+
+ case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+ expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
+
+ case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+ val interval = new CalendarInterval(
+ lit.getCalendarInterval.getMonths,
+ lit.getCalendarInterval.getDays,
+ lit.getCalendarInterval.getMicroseconds)
+ expressions.Literal(interval, CalendarIntervalType)
+
+ case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+ expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
+
+ case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+ expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
+
+ case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+ val literals = lit.getArray.getValuesList.asScala.toArray.map(toCatalystExpression)
+ CreateArray(literals)
+
+ case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+ val literals = lit.getStruct.getFieldsList.asScala.toArray.map(toCatalystExpression)
+ CreateStruct(literals)
+
+ case proto.Expression.Literal.LiteralTypeCase.MAP =>
+ val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair =>
+ toCatalystExpression(pair.getKey) :: toCatalystExpression(pair.getValue) :: Nil
+ }
+ CreateMap(literals)
+
+ case _ =>
+ throw InvalidPlanInput(
+ s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
+ s"(${lit.getLiteralTypeCase.name})")
+ }
+ }
+
+ def toCatalystValue(lit: proto.Expression.Literal): Any = {
+ lit.getLiteralTypeCase match {
+ case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+ lit.getArray.getValuesList.asScala.toArray.map(toCatalystValue)
+
+ case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+ val literals = lit.getStruct.getFieldsList.asScala.map(toCatalystValue).toSeq
+ InternalRow(literals: _*)
+
+ case proto.Expression.Literal.LiteralTypeCase.MAP =>
+ lit.getMap.getPairsList.asScala.toArray.map { pair =>
+ toCatalystValue(pair.getKey) -> toCatalystValue(pair.getValue)
+ }.toMap
+
+ case proto.Expression.Literal.LiteralTypeCase.STRING => lit.getString
+
+ case _ => toCatalystExpression(lit).asInstanceOf[expressions.Literal].value
+ }
+ }
+
+ def toConnectProtoValue(value: Any): proto.Expression.Literal = {
+ value match {
+ case null => proto.Expression.Literal.newBuilder().setNull(true).build()
+ case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build()
+ case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build()
+ case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build()
+ case i: Int => proto.Expression.Literal.newBuilder().setInteger(i).build()
+ case l: Long => proto.Expression.Literal.newBuilder().setLong(l).build()
+ case f: Float => proto.Expression.Literal.newBuilder().setFloat(f).build()
+ case d: Double => proto.Expression.Literal.newBuilder().setDouble(d).build()
+ case s: String => proto.Expression.Literal.newBuilder().setString(s).build()
+ case o => throw new Exception(s"Unsupported value type: $o")
+ }
+ }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 982f9188e1d..d8b7843fbe7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -33,15 +33,15 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
-import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils
final case class InvalidPlanInput(
@@ -93,6 +93,7 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
transformRenameColumnsByNameToNameMap(rel.getRenameColumnsByNameToNameMap)
case proto.Relation.RelTypeCase.WITH_COLUMNS =>
transformWithColumns(rel.getWithColumns)
+ case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -199,17 +200,7 @@ class SparkConnectPlanner(session: SparkSession) {
} else {
val valueMap = mutable.Map.empty[String, Any]
cols.zip(values).foreach { case (col, value) =>
- value.getLiteralTypeCase match {
- case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
- valueMap.update(col, value.getBoolean)
- case proto.Expression.Literal.LiteralTypeCase.LONG =>
- valueMap.update(col, value.getLong)
- case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
- valueMap.update(col, value.getDouble)
- case proto.Expression.Literal.LiteralTypeCase.STRING =>
- valueMap.update(col, value.getString)
- case other => throw InvalidPlanInput(s"Unsupported value type: $other")
- }
+ valueMap.update(col, toCatalystValue(value))
}
dataset.na.fill(valueMap = valueMap.toMap).logicalPlan
}
@@ -233,19 +224,11 @@ class SparkConnectPlanner(session: SparkSession) {
}
private def transformReplace(rel: proto.NAReplace): LogicalPlan = {
- def convert(value: proto.Expression.Literal): Any = {
- value.getLiteralTypeCase match {
- case proto.Expression.Literal.LiteralTypeCase.NULL => null
- case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => value.getBoolean
- case proto.Expression.Literal.LiteralTypeCase.DOUBLE => value.getDouble
- case proto.Expression.Literal.LiteralTypeCase.STRING => value.getString
- case other => throw InvalidPlanInput(s"Unsupported value type: $other")
- }
- }
-
val replacement = mutable.Map.empty[Any, Any]
rel.getReplacementsList.asScala.foreach { replace =>
- replacement.update(convert(replace.getOldValue), convert(replace.getNewValue))
+ replacement.update(
+ toCatalystValue(replace.getOldValue),
+ toCatalystValue(replace.getNewValue))
}
if (rel.getColsCount == 0) {
@@ -313,6 +296,11 @@ class SparkConnectPlanner(session: SparkSession) {
.logicalPlan
}
+ private def transformHint(rel: proto.Hint): LogicalPlan = {
+ val params = rel.getParametersList.asScala.map(toCatalystValue).toSeq
+ UnresolvedHint(rel.getName, params, transformRelation(rel.getInput))
+ }
+
private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
@@ -426,90 +414,7 @@ class SparkConnectPlanner(session: SparkSession) {
* Expression
*/
private def transformLiteral(lit: proto.Expression.Literal): Expression = {
- lit.getLiteralTypeCase match {
- case proto.Expression.Literal.LiteralTypeCase.NULL =>
- expressions.Literal(null, NullType)
-
- case proto.Expression.Literal.LiteralTypeCase.BINARY =>
- expressions.Literal(lit.getBinary.toByteArray, BinaryType)
-
- case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
- expressions.Literal(lit.getBoolean, BooleanType)
-
- case proto.Expression.Literal.LiteralTypeCase.BYTE =>
- expressions.Literal(lit.getByte, ByteType)
-
- case proto.Expression.Literal.LiteralTypeCase.SHORT =>
- expressions.Literal(lit.getShort, ShortType)
-
- case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
- expressions.Literal(lit.getInteger, IntegerType)
-
- case proto.Expression.Literal.LiteralTypeCase.LONG =>
- expressions.Literal(lit.getLong, LongType)
-
- case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
- expressions.Literal(lit.getFloat, FloatType)
-
- case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
- expressions.Literal(lit.getDouble, DoubleType)
-
- case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
- val decimal = Decimal.apply(lit.getDecimal.getValue)
- var precision = decimal.precision
- if (lit.getDecimal.hasPrecision) {
- precision = math.max(precision, lit.getDecimal.getPrecision)
- }
- var scale = decimal.scale
- if (lit.getDecimal.hasScale) {
- scale = math.max(scale, lit.getDecimal.getScale)
- }
- expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale))
-
- case proto.Expression.Literal.LiteralTypeCase.STRING =>
- expressions.Literal(UTF8String.fromString(lit.getString), StringType)
-
- case proto.Expression.Literal.LiteralTypeCase.DATE =>
- expressions.Literal(lit.getDate, DateType)
-
- case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
- expressions.Literal(lit.getTimestamp, TimestampType)
-
- case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
- expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
-
- case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
- val interval = new CalendarInterval(
- lit.getCalendarInterval.getMonths,
- lit.getCalendarInterval.getDays,
- lit.getCalendarInterval.getMicroseconds)
- expressions.Literal(interval, CalendarIntervalType)
-
- case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
- expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
-
- case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
- expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
-
- case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
- val literals = lit.getArray.getValuesList.asScala.toArray.map(transformLiteral)
- CreateArray(literals)
-
- case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
- val literals = lit.getStruct.getFieldsList.asScala.toArray.map(transformLiteral)
- CreateStruct(literals)
-
- case proto.Expression.Literal.LiteralTypeCase.MAP =>
- val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair =>
- transformLiteral(pair.getKey) :: transformLiteral(pair.getValue) :: Nil
- }
- CreateMap(literals)
-
- case _ =>
- throw InvalidPlanInput(
- s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
- s"(${lit.getLiteralTypeCase.name})")
- }
+ toCatalystExpression(lit)
}
private def transformLimit(limit: proto.Limit): LogicalPlan = {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala
new file mode 100644
index 00000000000..dc8254c47f3
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystValue, toConnectProtoValue}
+
+class LiteralValueProtoConverterSuite extends AnyFunSuite {
+
+ test("basic proto value and catalyst value conversion") {
+ val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark")
+ for (v <- values) {
+ assertResult(v)(toCatalystValue(toConnectProtoValue(v)))
+ }
+ }
+}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 362973a90ef..5362453da50 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -571,4 +572,49 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build()))
assert(df.schema.fields.toSeq.map(_.name) == Seq("id"))
}
+
+ test("Hint") {
+ val input = proto.Relation
+ .newBuilder()
+ .setSql(
+ proto.SQL
+ .newBuilder()
+ .setQuery("select id from range(10)")
+ .build())
+
+ val logical = transform(
+ proto.Relation
+ .newBuilder()
+ .setHint(
+ proto.Hint
+ .newBuilder()
+ .setInput(input)
+ .setName("REPARTITION")
+ .addParameters(toConnectProtoValue(10000)))
+ .build())
+
+ val df = Dataset.ofRows(spark, logical)
+ assert(df.rdd.partitions.length == 10000)
+ }
+
+ test("Hint with illegal name will be ignored") {
+ val input = proto.Relation
+ .newBuilder()
+ .setSql(
+ proto.SQL
+ .newBuilder()
+ .setQuery("select id from range(10)")
+ .build())
+
+ val logical = transform(
+ proto.Relation
+ .newBuilder()
+ .setHint(
+ proto.Hint
+ .newBuilder()
+ .setInput(input)
+ .setName("illegal"))
+ .build())
+ assert(10 === Dataset.ofRows(spark, logical).count())
+ }
}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 1e4e18c3c8f..074372b6c8d 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -546,6 +546,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
sparkTestRelation.select(col("id").cast(StringType)))
}
+ test("Test Hint") {
+ comparePlans(connectTestRelation.hint("COALESCE", 3), sparkTestRelation.hint("COALESCE", 3))
+ }
+
private def createLocalRelationProtoByAttributeReferences(
attrs: Seq[AttributeReference]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()