This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 47d2c9ca064e [SPARK-49712][SQL] Remove encoderFor from connect-client-jvm
47d2c9ca064e is described below
commit 47d2c9ca064e9d80a444d21cfac47ca334230242
Author: Herman van Hovell <[email protected]>
AuthorDate: Sat Sep 28 16:27:13 2024 -0700
[SPARK-49712][SQL] Remove encoderFor from connect-client-jvm
### What changes were proposed in this pull request?
This PR removes `sql.encoderFor` from the connect-client-jvm module and replaces it with `AgnosticEncoders.agnosticEncoderFor`.
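
For illustration, a minimal before/after sketch of the call-site change (the call shown is taken from the test diff below; the surrounding context is assumed):

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor

// Before (removed in this PR): the connect-client-jvm package object
// defined a private helper that simply cast the implicit Encoder:
//   private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] =
//     implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]]

// After: resolve the agnostic encoder via the shared helper instead.
val encoder = agnosticEncoderFor(Encoders.javaSerialization[(Int, String)])
```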
### Why are the changes needed?
The package-level `sql.encoderFor` helper will cause a clash when we swap the interface and the implementation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48266 from hvanhovell/SPARK-49712.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../jvm/src/main/scala/org/apache/spark/sql/Dataset.scala | 10 +++++-----
.../org/apache/spark/sql/KeyValueGroupedDataset.scala | 14 +++++++-------
.../org/apache/spark/sql/RelationalGroupedDataset.scala | 7 ++++++-
.../src/main/scala/org/apache/spark/sql/SparkSession.scala | 4 ++--
.../org/apache/spark/sql/internal/UdfToProtoUtils.scala | 10 +++++-----
.../jvm/src/main/scala/org/apache/spark/sql/package.scala | 6 ------
.../scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala | 3 ++-
.../spark/sql/connect/client/arrow/ArrowEncoderSuite.scala | 8 ++++----
8 files changed, 31 insertions(+), 31 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index d2877ccaf06c..6bae04ef8023 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -143,7 +143,7 @@ class Dataset[T] private[sql] (
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)
- private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)
+ private[sql] val agnosticEncoder: AgnosticEncoder[T] = agnosticEncoderFor(encoder)
override def toString: String = {
try {
@@ -437,7 +437,7 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
- val encoder = encoderFor(c1.encoder)
+ val encoder = agnosticEncoderFor(c1.encoder)
val col = if (encoder.schema == encoder.dataType) {
functions.inline(functions.array(c1))
} else {
@@ -452,7 +452,7 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoder = ProductEncoder.tuple(columns.map(c => encoderFor(c.encoder)))
+ val encoder = ProductEncoder.tuple(columns.map(c => agnosticEncoderFor(c.encoder)))
selectUntyped(encoder, columns)
}
@@ -526,7 +526,7 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
- KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
+ KeyValueGroupedDatasetImpl[K, T](this, agnosticEncoderFor[K], func)
}
/** @inheritdoc */
@@ -881,7 +881,7 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
- val outputEncoder = encoderFor[U]
+ val outputEncoder = agnosticEncoderFor[U]
val udf = SparkUserDefinedFunction(
function = func,
inputEncoders = agnosticEncoder :: Nil,
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 6bf251890147..63b5f27c4745 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.UdfUtils
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
@@ -398,7 +398,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
new KeyValueGroupedDatasetImpl[L, V, IK, IV](
sparkSession,
plan,
- encoderFor[L],
+ agnosticEncoderFor[L],
ivEncoder,
vEncoder,
groupingExprs,
@@ -412,7 +412,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
plan,
kEncoder,
ivEncoder,
- encoderFor[W],
+ agnosticEncoderFor[W],
groupingExprs,
valueMapFunc
.map(_.andThen(valueFunc))
@@ -430,7 +430,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
// Apply mapValues changes to the udf
val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc)
- val outputEncoder = encoderFor[U]
+ val outputEncoder = agnosticEncoderFor[U]
sparkSession.newDataset[U](outputEncoder) { builder =>
builder.getGroupMapBuilder
.setInput(plan.getRoot)
@@ -446,7 +446,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]]
// Apply mapValues changes to the udf
val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc)
- val outputEncoder = encoderFor[R]
+ val outputEncoder = agnosticEncoderFor[R]
sparkSession.newDataset[R](outputEncoder) { builder =>
builder.getCoGroupMapBuilder
.setInput(plan.getRoot)
@@ -461,7 +461,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
// TODO(SPARK-43415): For each column, apply the valueMap func first...
- val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder)))
+ val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => agnosticEncoderFor(c.encoder)))
sparkSession.newDataset(rEnc) { builder =>
builder.getAggregateBuilder
.setInput(plan.getRoot)
@@ -501,7 +501,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
null
}
- val outputEncoder = encoderFor[U]
+ val outputEncoder = agnosticEncoderFor[U]
val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc)
sparkSession.newDataset[U](outputEncoder) { builder =>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 14ceb3f4bb14..5bded40b0d13 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
import org.apache.spark.sql.connect.ConnectConversions._
/**
@@ -82,7 +83,11 @@ class RelationalGroupedDataset private[sql] (
/** @inheritdoc */
def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
- KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs)
+ KeyValueGroupedDatasetImpl[K, T](
+ df,
+ agnosticEncoderFor[K],
+ agnosticEncoderFor[T],
+ groupingExprs)
}
/** @inheritdoc */
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b31670c1da57..222b5ea79508 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
@@ -136,7 +136,7 @@ class SparkSession private[sql] (
/** @inheritdoc */
def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = {
- createDataset(encoderFor[T], data.iterator)
+ createDataset(agnosticEncoderFor[T], data.iterator)
}
/** @inheritdoc */
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala
index 85ce2cb82043..409c43f480b8 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala
@@ -25,9 +25,9 @@ import com.google.protobuf.ByteString
import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
import org.apache.spark.sql.connect.common.DataTypeProtoConverter.toConnectProtoType
import org.apache.spark.sql.connect.common.UdfPacket
-import org.apache.spark.sql.encoderFor
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils}
@@ -79,12 +79,12 @@ private[sql] object UdfToProtoUtils {
udf match {
case f: SparkUserDefinedFunction =>
val outputEncoder = f.outputEncoder
- .map(e => encoderFor(e))
+ .map(e => agnosticEncoderFor(e))
.getOrElse(RowEncoder.encoderForDataType(f.dataType, lenient = false))
val inputEncoders = if (f.inputEncoders.forall(_.isEmpty)) {
Nil // Java UDFs have no bindings for their inputs.
} else {
- f.inputEncoders.map(e => encoderFor(e.get)) // TODO support Any and UnboundRow.
+ f.inputEncoders.map(e => agnosticEncoderFor(e.get)) // TODO support Any and UnboundRow.
}
inputEncoders.foreach(e => protoUdf.addInputTypes(toConnectProtoType(e.dataType)))
protoUdf
@@ -93,8 +93,8 @@ private[sql] object UdfToProtoUtils {
.setAggregate(false)
f.givenName.foreach(invokeUdf.setFunctionName)
case f: UserDefinedAggregator[_, _, _] =>
- val outputEncoder = encoderFor(f.aggregator.outputEncoder)
- val inputEncoder = encoderFor(f.inputEncoder)
+ val outputEncoder = agnosticEncoderFor(f.aggregator.outputEncoder)
+ val inputEncoder = agnosticEncoderFor(f.inputEncoder)
protoUdf
.setPayload(toUdfPacketBytes(f.aggregator, inputEncoder :: Nil, outputEncoder))
.addInputTypes(toConnectProtoType(inputEncoder.dataType))
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
index 556b472283a3..ada94b76fcbc 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
@@ -17,12 +17,6 @@
package org.apache.spark
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-
package object sql {
type DataFrame = Dataset[Row]
-
- private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = {
- implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]]
- }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index 57342e12fcb5..b3b8020b1e4c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -26,6 +26,7 @@ import org.apache.arrow.memory.RootAllocator
import org.apache.commons.lang3.SystemUtils
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer}
import org.apache.spark.sql.test.ConnectFunSuite
@@ -55,7 +56,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {
import org.apache.spark.util.ArrayImplicits._
import spark.implicits._
def testImplicit[T: Encoder](expected: T): Unit = {
- val encoder = encoderFor[T]
+ val encoder = agnosticEncoderFor[T]
val allocator = new RootAllocator()
try {
val batch = ArrowSerializer.serialize(
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 5397dae9dcc5..7176c582d0bb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -30,11 +30,11 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.VarBinaryVector
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.{sql, SparkRuntimeException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql.{AnalysisException, Encoders, Row}
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, Primi [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, Primitiv [...]
import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
@@ -770,7 +770,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
}
test("java serialization") {
- val encoder = sql.encoderFor(Encoders.javaSerialization[(Int, String)])
+ val encoder = agnosticEncoderFor(Encoders.javaSerialization[(Int, String)])
roundTripAndCheckIdentical(encoder) { () =>
Iterator.tabulate(10)(i => (i, "itr_" + i))
}
@@ -778,7 +778,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
test("kryo serialization") {
val e = intercept[SparkRuntimeException] {
- val encoder = sql.encoderFor(Encoders.kryo[(Int, String)])
+ val encoder = agnosticEncoderFor(Encoders.kryo[(Int, String)])
roundTripAndCheckIdentical(encoder) { () =>
Iterator.tabulate(10)(i => (i, "itr_" + i))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]