This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new def02cb9da3 [SPARK-42755][CONNECT] Factor literal value conversion out to `connect-common` def02cb9da3 is described below commit def02cb9da38bb2029e2859cd06f517d547f482f Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Mon Mar 13 14:29:15 2023 +0800 [SPARK-42755][CONNECT] Factor literal value conversion out to `connect-common` ### What changes were proposed in this pull request? Factor literal value conversion out to `connect-common`. ### Why are the changes needed? When trying to build protos of literal arrays on the server side for ML, I found we have two implementations: `LiteralExpressionProtoConverter.toConnectProtoValue` in the server module, which doesn't support arrays; and `LiteralProtoConverter.toLiteralProto` in the client module, which supports more types. We'd better factor it out to the common module, so that both client and server can leverage it. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing UT Closes #40375 from zhengruifeng/connect_mv_literal_common. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> (cherry picked from commit 43caae31dfa05b3d237acfa3115bd0e7b4e540ed) Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../scala/org/apache/spark/sql/functions.scala | 2 +- .../common/LiteralValueProtoConverter.scala} | 10 +++--- .../org/apache/spark/sql/connect/dsl/package.scala | 14 ++++---- ...scala => LiteralExpressionProtoConverter.scala} | 22 +------------ .../sql/connect/planner/SparkConnectPlanner.scala | 2 +- .../service/SparkConnectStreamHandler.scala | 4 +-- ... 
=> LiteralExpressionProtoConverterSuite.scala} | 7 ++-- .../connect/planner/SparkConnectPlannerSuite.scala | 38 +++++++++------------- .../connect/planner/SparkConnectProtoSuite.scala | 14 ++++---- 9 files changed, 44 insertions(+), 69 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 8ce90886e0f..29c2e89c537 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,8 +23,8 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter._ import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction} -import org.apache.spark.sql.expressions.LiteralProtoConverter._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.types.DataType.parseTypeWithFallback diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala rename to connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index daddfa9b5af..ceef9b21244 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -14,7 +14,8 @@ * See the License for the specific language 
governing permissions and * limitations under the License. */ -package org.apache.spark.sql.expressions + +package org.apache.spark.sql.connect.common import java.lang.{Boolean => JBoolean, Byte => JByte, Character => JChar, Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong, Short => JShort} import java.math.{BigDecimal => JBigDecimal} @@ -25,12 +26,11 @@ import com.google.protobuf.ByteString import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.connect.client.unsupported import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -object LiteralProtoConverter { +object LiteralValueProtoConverter { private lazy val nullType = proto.DataType.newBuilder().setNull(proto.DataType.NULL.getDefaultInstance).build() @@ -93,7 +93,7 @@ object LiteralProtoConverter { case v: CalendarInterval => builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) case null => builder.setNull(nullType) - case _ => unsupported(s"literal $literal not supported (yet).") + case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") } } @@ -103,7 +103,7 @@ object LiteralProtoConverter { * @return * proto.Expression.Literal */ - private def toLiteralProto(literal: Any): proto.Expression.Literal = + def toLiteralProto(literal: Any): proto.Expression.Literal = toLiteralProtoBuilder(literal).build() private def toDataType(clz: Class[_]): DataType = clz match { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 7e60c5f9a28..21b9180ccfb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -26,8 +26,8 @@ import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.connect.proto.SetOperation.SetOpType import org.apache.spark.sql.{Observation, SaveMode} import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.planner.{SaveModeConverter, TableSaveMethodConverter} -import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -342,7 +342,7 @@ package object dsl { proto.NAFill .newBuilder() .setInput(logicalPlan) - .addAllValues(Seq(toConnectProtoValue(value)).asJava) + .addAllValues(Seq(toLiteralProto(value)).asJava) .build()) .build() } @@ -355,13 +355,13 @@ package object dsl { .newBuilder() .setInput(logicalPlan) .addAllCols(cols.asJava) - .addAllValues(Seq(toConnectProtoValue(value)).asJava) + .addAllValues(Seq(toLiteralProto(value)).asJava) .build()) .build() } def fillValueMap(valueMap: Map[String, Any]): Relation = { - val (cols, values) = valueMap.mapValues(toConnectProtoValue).toSeq.unzip + val (cols, values) = valueMap.mapValues(toLiteralProto).toSeq.unzip Relation .newBuilder() .setFillNa( @@ -422,8 +422,8 @@ package object dsl { replace.addReplacements( proto.NAReplace.Replacement .newBuilder() - .setOldValue(toConnectProtoValue(oldValue)) - .setNewValue(toConnectProtoValue(newValue))) + .setOldValue(toLiteralProto(oldValue)) + .setNewValue(toLiteralProto(newValue))) } Relation @@ -978,7 +978,7 @@ package object dsl { def hint(name: String, parameters: Any*): Relation = { val expressions = parameters.map { parameter => - proto.Expression.newBuilder().setLiteral(toConnectProtoValue(parameter)).build() + proto.Expression.newBuilder().setLiteral(toLiteralProto(parameter)).build() } Relation diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala similarity index 87% rename from connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala rename to connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala index 7a580913867..9f2baea5737 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanI import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -object LiteralValueProtoConverter { +object LiteralExpressionProtoConverter { /** * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. 
@@ -121,25 +121,6 @@ object LiteralValueProtoConverter { } } - def toConnectProtoValue(value: Any): proto.Expression.Literal = { - value match { - case null => - proto.Expression.Literal - .newBuilder() - .setNull(DataTypeProtoConverter.toConnectProtoType(NullType)) - .build() - case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build() - case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build() - case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build() - case i: Int => proto.Expression.Literal.newBuilder().setInteger(i).build() - case l: Long => proto.Expression.Literal.newBuilder().setLong(l).build() - case f: Float => proto.Expression.Literal.newBuilder().setFloat(f).build() - case d: Double => proto.Expression.Literal.newBuilder().setDouble(d).build() - case s: String => proto.Expression.Literal.newBuilder().setString(s).build() - case o => throw new Exception(s"Unsupported value type: $o") - } - } - private def toArrayData(array: proto.Expression.Literal.Array): Any = { def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit tag: ClassTag[T]): Array[T] = { @@ -195,5 +176,4 @@ object LiteralValueProtoConverter { throw InvalidPlanInput(s"Unsupported Literal Type: $elementType)") } } - } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 24717e07b00..a057bd8d6c1 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResul import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import 
org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE -import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue} +import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter.{toCatalystExpression, toCatalystValue} import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry import org.apache.spark.sql.connect.service.SparkConnectStreamHandler import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 0dd1741f099..104d840ed52 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -28,8 +28,8 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE -import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} @@ -216,7 +216,7 @@ object SparkConnectStreamHandler { sessionId: String, dataframe: DataFrame): ExecutePlanResponse = { val observedMetrics = 
dataframe.queryExecution.observedMetrics.map { case (name, row) => - val cols = (0 until row.length).map(i => toConnectProtoValue(row(i))) + val cols = (0 until row.length).map(i => toLiteralProto(row(i))) ExecutePlanResponse.ObservedMetrics .newBuilder() .setName(name) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala similarity index 76% rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala rename to connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index 7c8ee6209ac..c3479456617 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.connect.planner import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite -import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystValue, toConnectProtoValue} +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto +import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter.toCatalystValue -class LiteralValueProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite +class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite test("basic proto value and catalyst value conversion") { val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark") for (v <- values) { - assertResult(v)(toCatalystValue(toConnectProtoValue(v))) + assertResult(v)(toCatalystValue(toLiteralProto(v))) } } } diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index b79d91d2d10..b6b214c839d 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.connect.common.InvalidPlanInput -import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -602,13 +602,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val logical = transform( proto.Relation .newBuilder() - .setHint( - proto.Hint - .newBuilder() - .setInput(input) - .setName("REPARTITION") - .addParameters( - proto.Expression.newBuilder().setLiteral(toConnectProtoValue(10000)).build())) + .setHint(proto.Hint + .newBuilder() + .setInput(input) + .setName("REPARTITION") + .addParameters(proto.Expression.newBuilder().setLiteral(toLiteralProto(10000)).build())) .build()) val df = Dataset.ofRows(spark, logical) @@ -648,13 +646,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val logical = transform( proto.Relation .newBuilder() - .setHint( - proto.Hint - .newBuilder() - .setInput(input) - .setName("REPARTITION") - .addParameters( - 
proto.Expression.newBuilder().setLiteral(toConnectProtoValue("id")).build())) + .setHint(proto.Hint + .newBuilder() + .setInput(input) + .setName("REPARTITION") + .addParameters(proto.Expression.newBuilder().setLiteral(toLiteralProto("id")).build())) .build()) assert(10 === Dataset.ofRows(spark, logical).count()) } @@ -671,13 +667,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val logical = transform( proto.Relation .newBuilder() - .setHint( - proto.Hint - .newBuilder() - .setInput(input) - .setName("REPARTITION") - .addParameters( - proto.Expression.newBuilder().setLiteral(toConnectProtoValue(true)).build())) + .setHint(proto.Hint + .newBuilder() + .setInput(input) + .setName("REPARTITION") + .addParameters(proto.Expression.newBuilder().setLiteral(toLiteralProto(true)).build())) .build()) intercept[AnalysisException](Dataset.ofRows(spark, logical)) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 00ff6ac2fb6..9cc714d630b 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -32,11 +32,11 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInt import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.commands._ import 
org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ -import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -245,7 +245,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val connectPlan3 = connectTestRelation.rollup("id".protoAttr, "name".protoAttr)( - proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build()) + proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build()) .as("agg1")) val sparkPlan3 = sparkTestRelation @@ -269,7 +269,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val connectPlan3 = connectTestRelation.cube("id".protoAttr, "name".protoAttr)( - proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build()) + proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build()) .as("agg1")) val sparkPlan3 = sparkTestRelation @@ -282,8 +282,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val connectPlan1 = connectTestRelation.pivot("id".protoAttr)( "name".protoAttr, - Seq("a", "b", "c").map(toConnectProtoValue))( - proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build()) + Seq("a", "b", "c").map(toLiteralProto))( + proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build()) .as("agg1")) val sparkPlan1 = sparkTestRelation @@ -295,8 +295,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val connectPlan2 = connectTestRelation.pivot("name".protoAttr)( "id".protoAttr, - Seq(1, 2, 3).map(toConnectProtoValue))( - proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build()) + Seq(1, 2, 
3).map(toLiteralProto))( + proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build()) .as("agg1")) val sparkPlan2 = sparkTestRelation --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org