This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 31cd6991558c [SPARK-49424][CONNECT][SQL] Consolidate Encoders.scala
31cd6991558c is described below
commit 31cd6991558c45cea56ba25cb89f13e64e3d93fa
Author: Herman van Hovell <[email protected]>
AuthorDate: Tue Sep 17 23:00:49 2024 -0400
[SPARK-49424][CONNECT][SQL] Consolidate Encoders.scala
### What changes were proposed in this pull request?
This PR moves Encoders.scala to sql/api. It removes the duplicate one in
connect.
### Why are the changes needed?
We are creating a unified Scala interface for Classic and Connect.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48021 from hvanhovell/SPARK-49424.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 +
project/MimaExcludes.scala | 4 +
.../main/scala/org/apache/spark/sql/Encoders.scala | 152 +++++----
.../sql/catalyst/encoders/AgnosticEncoder.scala | 13 +
.../apache/spark/sql/errors/ExecutionErrors.scala | 18 ++
.../main/scala/org/apache/spark/sql/Encoders.scala | 348 ---------------------
.../sql/catalyst/encoders/ExpressionEncoder.scala | 51 +--
.../spark/sql/errors/QueryExecutionErrors.scala | 11 -
.../catalyst/encoders/EncoderResolutionSuite.scala | 8 +-
.../catalyst/encoders/ExpressionEncoderSuite.scala | 23 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 17 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 84 +++--
.../apache/spark/sql/KeyValueGroupedDataset.scala | 37 +--
.../spark/sql/RelationalGroupedDataset.scala | 8 +-
.../spark/sql/execution/aggregate/udaf.scala | 2 +-
.../continuous/ContinuousTextSocketSource.scala | 6 +-
.../spark/sql/expressions/ReduceAggregator.scala | 11 +-
.../apache/spark/sql/internal/TypedAggUtils.scala | 8 +-
.../FlatMapGroupsWithStateExecHelperSuite.scala | 4 +-
.../execution/streaming/state/ListStateSuite.scala | 18 +-
.../execution/streaming/state/MapStateSuite.scala | 15 +-
.../state/StatefulProcessorHandleSuite.scala | 24 +-
.../sql/execution/streaming/state/TimerSuite.scala | 22 +-
.../streaming/state/ValueStateSuite.scala | 30 +-
24 files changed, 300 insertions(+), 620 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 57b3d33741e9..25dd676c4aff 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1996,6 +1996,12 @@
},
"sqlState" : "42903"
},
+ "INVALID_AGNOSTIC_ENCODER" : {
+ "message" : [
+ "Found an invalid agnostic encoder. Expects an instance of
AgnosticEncoder but got <encoderType>. For more information consult
'<docroot>/api/java/index.html?org/apache/spark/sql/Encoder.html'."
+ ],
+ "sqlState" : "42001"
+ },
"INVALID_ARRAY_INDEX" : {
"message" : [
"The index <indexValue> is out of bounds. The array has <arraySize>
elements. Use the SQL function `get()` to tolerate accessing element at invalid
index and return NULL instead. If necessary set <ansiConfig> to \"false\" to
bypass this error."
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 68433b501bcc..dfe7b14e2ec6 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -160,6 +160,10 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriterV2"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.WriteConfigMethods"),
+ // SPARK-49424: Shared Encoders
+
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
+
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"),
+
// SPARK-49413: Create a shared RuntimeConfig interface.
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"),
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
similarity index 77%
rename from
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index 33a322109c1b..9976b34f7a01 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -14,95 +14,99 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql
-import scala.reflect.ClassTag
+import java.lang.reflect.Modifier
+
+import scala.reflect.{classTag, ClassTag}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
JavaSerializationCodec, KryoSerializationCodec, RowEncoder => RowEncoderFactory}
+import org.apache.spark.sql.catalyst.encoders.{Codec, JavaSerializationCodec,
KryoSerializationCodec, RowEncoder => SchemaInference}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.types._
/**
* Methods for creating an [[Encoder]].
*
- * @since 3.5.0
+ * @since 1.6.0
*/
object Encoders {
/**
* An encoder for nullable boolean type. The Scala primitive encoder is
available as
* [[scalaBoolean]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
/**
* An encoder for nullable byte type. The Scala primitive encoder is
available as [[scalaByte]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder
/**
* An encoder for nullable short type. The Scala primitive encoder is
available as
* [[scalaShort]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder
/**
* An encoder for nullable int type. The Scala primitive encoder is
available as [[scalaInt]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def INT: Encoder[java.lang.Integer] = BoxedIntEncoder
/**
* An encoder for nullable long type. The Scala primitive encoder is
available as [[scalaLong]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def LONG: Encoder[java.lang.Long] = BoxedLongEncoder
/**
* An encoder for nullable float type. The Scala primitive encoder is
available as
* [[scalaFloat]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder
/**
* An encoder for nullable double type. The Scala primitive encoder is
available as
* [[scalaDouble]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder
/**
* An encoder for nullable string type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def STRING: Encoder[java.lang.String] = StringEncoder
/**
* An encoder for nullable decimal type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER
/**
* An encoder for nullable date type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
- def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false)
+ def DATE: Encoder[java.sql.Date] = STRICT_DATE_ENCODER
/**
* Creates an encoder that serializes instances of the `java.time.LocalDate`
class to the
* internal representation of nullable Catalyst's DateType.
*
- * @since 3.5.0
+ * @since 3.0.0
*/
def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER
@@ -110,14 +114,14 @@ object Encoders {
* Creates an encoder that serializes instances of the
`java.time.LocalDateTime` class to the
* internal representation of nullable Catalyst's TimestampNTZType.
*
- * @since 3.5.0
+ * @since 3.4.0
*/
def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder
/**
* An encoder for nullable timestamp type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER
@@ -125,14 +129,14 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.Instant`
class to the internal
* representation of nullable Catalyst's TimestampType.
*
- * @since 3.5.0
+ * @since 3.0.0
*/
def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER
/**
* An encoder for arrays of bytes.
*
- * @since 3.5.0
+ * @since 1.6.1
*/
def BINARY: Encoder[Array[Byte]] = BinaryEncoder
@@ -140,7 +144,7 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.Duration`
class to the
* internal representation of nullable Catalyst's DayTimeIntervalType.
*
- * @since 3.5.0
+ * @since 3.2.0
*/
def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder
@@ -148,7 +152,7 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.Period`
class to the internal
* representation of nullable Catalyst's YearMonthIntervalType.
*
- * @since 3.5.0
+ * @since 3.2.0
*/
def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder
@@ -166,7 +170,7 @@ object Encoders {
* - collection types: array, java.util.List, and map
* - nested java bean.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def bean[T](beanClass: Class[T]): Encoder[T] =
JavaTypeInference.encoderFor(beanClass)
@@ -175,71 +179,96 @@ object Encoders {
*
* @since 3.5.0
*/
- def row(schema: StructType): Encoder[Row] =
RowEncoderFactory.encoderFor(schema)
+ def row(schema: StructType): Encoder[Row] =
SchemaInference.encoderFor(schema)
/**
- * (Scala-specific) Creates an encoder that serializes objects of type T
using generic Java
- * serialization. This encoder maps T into a single byte array (binary)
field.
+ * (Scala-specific) Creates an encoder that serializes objects of type T
using Kryo. This
+ * encoder maps T into a single byte array (binary) field.
*
* T must be publicly accessible.
*
- * @note
- * This is extremely inefficient and should only be used as the last
resort.
- * @since 4.0.0
+ * @since 1.6.0
*/
- def javaSerialization[T: ClassTag]: Encoder[T] = {
- TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder,
JavaSerializationCodec)
- }
+ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)
/**
- * Creates an encoder that serializes objects of type T using generic Java
serialization. This
- * encoder maps T into a single byte array (binary) field.
+ * Creates an encoder that serializes objects of type T using Kryo. This
encoder maps T into a
+ * single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T
using generic Java
+ * serialization. This encoder maps T into a single byte array (binary)
field.
*
* T must be publicly accessible.
*
* @note
* This is extremely inefficient and should only be used as the last
resort.
- * @since 4.0.0
+ *
+ * @since 1.6.0
*/
- def javaSerialization[T](clazz: Class[T]): Encoder[T] =
javaSerialization(ClassTag[T](clazz))
+ def javaSerialization[T: ClassTag]: Encoder[T] =
genericSerializer(JavaSerializationCodec)
/**
- * (Scala-specific) Creates an encoder that serializes objects of type T
using Kryo. This
+ * Creates an encoder that serializes objects of type T using generic Java
serialization. This
* encoder maps T into a single byte array (binary) field.
*
* T must be publicly accessible.
*
- * @since 4.0.0
+ * @note
+ * This is extremely inefficient and should only be used as the last
resort.
+ *
+ * @since 1.6.0
*/
- def kryo[T: ClassTag]: Encoder[T] = {
- TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder,
KryoSerializationCodec)
+ def javaSerialization[T](clazz: Class[T]): Encoder[T] =
+ javaSerialization(ClassTag[T](clazz))
+
+ /** Throws an exception if T is not a public class. */
+ private def validatePublicClass[T: ClassTag](): Unit = {
+ if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
+ throw
ExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName)
+ }
+ }
+
+ /** A way to construct encoders using generic serializers. */
+ private def genericSerializer[T: ClassTag](
+ provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
+ if (classTag[T].runtimeClass.isPrimitive) {
+ throw ExecutionErrors.primitiveTypesNotSupportedError()
+ }
+
+ validatePublicClass[T]()
+
+ TransformingEncoder(classTag[T], BinaryEncoder, provider)
+ }
+
+ private[sql] def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
+
ProductEncoder.tuple(encoders.map(agnosticEncoderFor(_))).asInstanceOf[Encoder[T]]
}
/**
- * Creates an encoder that serializes objects of type T using Kryo. This
encoder maps T into a
- * single byte array (binary) field.
- *
- * T must be publicly accessible.
+ * An encoder for 1-ary tuples.
*
* @since 4.0.0
*/
- def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
- private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
-
ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
- }
+ def tuple[T1](e1: Encoder[T1]): Encoder[(T1)] = tupleEncoder(e1)
/**
* An encoder for 2-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] =
tupleEncoder(e1, e2)
/**
* An encoder for 3-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2, T3](
e1: Encoder[T1],
@@ -249,7 +278,7 @@ object Encoders {
/**
* An encoder for 4-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2, T3, T4](
e1: Encoder[T1],
@@ -260,7 +289,7 @@ object Encoders {
/**
* An encoder for 5-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2, T3, T4, T5](
e1: Encoder[T1],
@@ -271,49 +300,50 @@ object Encoders {
/**
* An encoder for Scala's product type (tuples, case classes, etc).
- * @since 3.5.0
+ * @since 2.0.0
*/
def product[T <: Product: TypeTag]: Encoder[T] =
ScalaReflection.encoderFor[T]
/**
* An encoder for Scala's primitive int type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaInt: Encoder[Int] = PrimitiveIntEncoder
/**
* An encoder for Scala's primitive long type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaLong: Encoder[Long] = PrimitiveLongEncoder
/**
* An encoder for Scala's primitive double type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder
/**
* An encoder for Scala's primitive float type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder
/**
* An encoder for Scala's primitive byte type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaByte: Encoder[Byte] = PrimitiveByteEncoder
/**
* An encoder for Scala's primitive short type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaShort: Encoder[Short] = PrimitiveShortEncoder
/**
* An encoder for Scala's primitive boolean type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder
+
}
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index a57849575549..10f734b3f84e 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -46,7 +46,20 @@ trait AgnosticEncoder[T] extends Encoder[T] {
def isStruct: Boolean = false
}
+/**
+ * Extract an [[AgnosticEncoder]] from an [[Encoder]].
+ */
+trait ToAgnosticEncoder[T] {
+ def encoder: AgnosticEncoder[T]
+}
+
object AgnosticEncoders {
+ def agnosticEncoderFor[T: Encoder]: AgnosticEncoder[T] =
implicitly[Encoder[T]] match {
+ case a: AgnosticEncoder[T] => a
+ case e: ToAgnosticEncoder[T @unchecked] => e.encoder
+ case other => throw ExecutionErrors.invalidAgnosticEncoderError(other)
+ }
+
case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
extends AgnosticEncoder[Option[E]] {
override def isPrimitive: Boolean = false
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 4890ff4431fe..698a7b096e1a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -217,9 +217,27 @@ private[sql] trait ExecutionErrors extends
DataTypeErrorsBase {
new SparkRuntimeException(errorClass = "CANNOT_USE_KRYO",
messageParameters = Map.empty)
}
+ def notPublicClassError(name: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2229",
+ messageParameters = Map("name" -> name))
+ }
+
+ def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(errorClass =
"_LEGACY_ERROR_TEMP_2230")
+ }
+
def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
}
+
+ def invalidAgnosticEncoderError(encoder: AnyRef): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_AGNOSTIC_ENCODER",
+ messageParameters = Map(
+ "encoderType" -> encoder.getClass.getName,
+ "docroot" -> SparkBuildInfo.spark_doc_root))
+ }
}
private[sql] object ExecutionErrors extends ExecutionErrors
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
deleted file mode 100644
index 7e040f6232fb..000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ /dev/null
@@ -1,348 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.lang.reflect.Modifier
-
-import scala.reflect.{classTag, ClassTag}
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Codec,
ExpressionEncoder, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder,
TransformingEncoder}
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types._
-
-/**
- * Methods for creating an [[Encoder]].
- *
- * @since 1.6.0
- */
-object Encoders {
-
- /**
- * An encoder for nullable boolean type.
- * The Scala primitive encoder is available as [[scalaBoolean]].
- * @since 1.6.0
- */
- def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
-
- /**
- * An encoder for nullable byte type.
- * The Scala primitive encoder is available as [[scalaByte]].
- * @since 1.6.0
- */
- def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
-
- /**
- * An encoder for nullable short type.
- * The Scala primitive encoder is available as [[scalaShort]].
- * @since 1.6.0
- */
- def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
-
- /**
- * An encoder for nullable int type.
- * The Scala primitive encoder is available as [[scalaInt]].
- * @since 1.6.0
- */
- def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
-
- /**
- * An encoder for nullable long type.
- * The Scala primitive encoder is available as [[scalaLong]].
- * @since 1.6.0
- */
- def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
-
- /**
- * An encoder for nullable float type.
- * The Scala primitive encoder is available as [[scalaFloat]].
- * @since 1.6.0
- */
- def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
-
- /**
- * An encoder for nullable double type.
- * The Scala primitive encoder is available as [[scalaDouble]].
- * @since 1.6.0
- */
- def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
-
- /**
- * An encoder for nullable string type.
- *
- * @since 1.6.0
- */
- def STRING: Encoder[java.lang.String] = ExpressionEncoder()
-
- /**
- * An encoder for nullable decimal type.
- *
- * @since 1.6.0
- */
- def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
-
- /**
- * An encoder for nullable date type.
- *
- * @since 1.6.0
- */
- def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.LocalDate`
class
- * to the internal representation of nullable Catalyst's DateType.
- *
- * @since 3.0.0
- */
- def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the
`java.time.LocalDateTime` class
- * to the internal representation of nullable Catalyst's TimestampNTZType.
- *
- * @since 3.4.0
- */
- def LOCALDATETIME: Encoder[java.time.LocalDateTime] = ExpressionEncoder()
-
- /**
- * An encoder for nullable timestamp type.
- *
- * @since 1.6.0
- */
- def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.Instant`
class
- * to the internal representation of nullable Catalyst's TimestampType.
- *
- * @since 3.0.0
- */
- def INSTANT: Encoder[java.time.Instant] = ExpressionEncoder()
-
- /**
- * An encoder for arrays of bytes.
- *
- * @since 1.6.1
- */
- def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.Duration`
class
- * to the internal representation of nullable Catalyst's DayTimeIntervalType.
- *
- * @since 3.2.0
- */
- def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.Period`
class
- * to the internal representation of nullable Catalyst's
YearMonthIntervalType.
- *
- * @since 3.2.0
- */
- def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
-
- /**
- * Creates an encoder for Java Bean of type T.
- *
- * T must be publicly accessible.
- *
- * supported types for java bean field:
- * - primitive types: boolean, int, double, etc.
- * - boxed types: Boolean, Integer, Double, etc.
- * - String
- * - java.math.BigDecimal, java.math.BigInteger
- * - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate,
java.time.Instant
- * - collection types: array, java.util.List, and map
- * - nested java bean.
- *
- * @since 1.6.0
- */
- def bean[T](beanClass: Class[T]): Encoder[T] =
ExpressionEncoder.javaBean(beanClass)
-
- /**
- * Creates a [[Row]] encoder for schema `schema`.
- *
- * @since 3.5.0
- */
- def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema)
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T
using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)
-
- /**
- * Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T
using generic Java
- * serialization. This encoder maps T into a single byte array (binary)
field.
- *
- * T must be publicly accessible.
- *
- * @note This is extremely inefficient and should only be used as the last
resort.
- *
- * @since 1.6.0
- */
- def javaSerialization[T: ClassTag]: Encoder[T] =
genericSerializer(JavaSerializationCodec)
-
- /**
- * Creates an encoder that serializes objects of type T using generic Java
serialization.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @note This is extremely inefficient and should only be used as the last
resort.
- *
- * @since 1.6.0
- */
- def javaSerialization[T](clazz: Class[T]): Encoder[T] =
javaSerialization(ClassTag[T](clazz))
-
- /** Throws an exception if T is not a public class. */
- private def validatePublicClass[T: ClassTag](): Unit = {
- if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
- throw
QueryExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName)
- }
- }
-
- /** A way to construct encoders using generic serializers. */
- private def genericSerializer[T: ClassTag](
- provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
- if (classTag[T].runtimeClass.isPrimitive) {
- throw QueryExecutionErrors.primitiveTypesNotSupportedError()
- }
-
- validatePublicClass[T]()
-
- ExpressionEncoder(TransformingEncoder(classTag[T], BinaryEncoder,
provider))
- }
-
- /**
- * An encoder for 2-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2](
- e1: Encoder[T1],
- e2: Encoder[T2]): Encoder[(T1, T2)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
- }
-
- /**
- * An encoder for 3-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2, T3](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
- }
-
- /**
- * An encoder for 4-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3),
encoderFor(e4))
- }
-
- /**
- * An encoder for 5-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4, T5](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4],
- e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
- ExpressionEncoder.tuple(
- encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4),
encoderFor(e5))
- }
-
- /**
- * An encoder for Scala's product type (tuples, case classes, etc).
- * @since 2.0.0
- */
- def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive int type.
- * @since 2.0.0
- */
- def scalaInt: Encoder[Int] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive long type.
- * @since 2.0.0
- */
- def scalaLong: Encoder[Long] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive double type.
- * @since 2.0.0
- */
- def scalaDouble: Encoder[Double] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive float type.
- * @since 2.0.0
- */
- def scalaFloat: Encoder[Float] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive byte type.
- * @since 2.0.0
- */
- def scalaByte: Encoder[Byte] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive short type.
- * @since 2.0.0
- */
- def scalaShort: Encoder[Short] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive boolean type.
- * @since 2.0.0
- */
- def scalaBoolean: Encoder[Boolean] = ExpressionEncoder()
-
-}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 8e39ae0389c2..d7d53230470d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -70,54 +70,6 @@ object ExpressionEncoder {
apply(JavaTypeInference.encoderFor(beanClass))
}
- /**
- * Given a set of N encoders, constructs a new encoder that produce objects
as items in an
- * N-tuple. Note that these encoders should be unresolved so that
information about
- * name/positional binding is preserved.
- * When `useNullSafeDeserializer` is true, the deserialization result for a
child will be null if
- * the input is null. It is false by default as most deserializers handle
null input properly and
- * don't require an extra null check. Some of them are null-tolerant, such
as the deserializer for
- * `Option[T]`, and we must not set it to true in this case.
- */
- def tuple(
- encoders: Seq[ExpressionEncoder[_]],
- useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
- val tupleEncoder = AgnosticEncoders.ProductEncoder.tuple(
- encoders.map(_.encoder),
- useNullSafeDeserializer)
- ExpressionEncoder(tupleEncoder)
- }
-
- // Tuple1
- def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
- tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
-
- def tuple[T1, T2](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
- tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
-
- def tuple[T1, T2, T3](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2],
- e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
- tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
-
- def tuple[T1, T2, T3, T4](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2],
- e3: ExpressionEncoder[T3],
- e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
- tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3,
T4)]]
-
- def tuple[T1, T2, T3, T4, T5](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2],
- e3: ExpressionEncoder[T3],
- e4: ExpressionEncoder[T4],
- e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
- tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3,
T4, T5)]]
-
private val anyObjectType = ObjectType(classOf[Any])
/**
@@ -189,7 +141,8 @@ case class ExpressionEncoder[T](
encoder: AgnosticEncoder[T],
objSerializer: Expression,
objDeserializer: Expression)
- extends Encoder[T] {
+ extends Encoder[T]
+ with ToAgnosticEncoder[T] {
override def clsTag: ClassTag[T] = encoder.clsTag
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 2ab86a5c5f03..4bc071155012 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1876,17 +1876,6 @@ private[sql] object QueryExecutionErrors extends
QueryErrorsBase with ExecutionE
cause = null)
}
- def notPublicClassError(name: String): SparkUnsupportedOperationException = {
- new SparkUnsupportedOperationException(
- errorClass = "_LEGACY_ERROR_TEMP_2229",
- messageParameters = Map(
- "name" -> name))
- }
-
- def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = {
- new SparkUnsupportedOperationException(errorClass =
"_LEGACY_ERROR_TEMP_2230")
- }
-
def onlySupportDataSourcesProvidingFileFormatError(providingClass: String):
Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2233",
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 28796db7c02e..35a27f41da80 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.SparkRuntimeException
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, Encoders}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference}
@@ -71,9 +71,9 @@ class EncoderResolutionSuite extends PlanTest {
}
test("real type doesn't match encoder schema but they are compatible: tupled
encoder") {
- val encoder = ExpressionEncoder.tuple(
- ExpressionEncoder[StringLongClass](),
- ExpressionEncoder[Long]())
+ val encoder = encoderFor(Encoders.tuple(
+ Encoders.product[StringLongClass],
+ Encoders.scalaLong))
val attrs = Seq($"a".struct($"a".string, $"b".byte), $"b".int)
testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 0c0c7f12f176..3b5cbed2cc52 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -321,29 +321,29 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
encodeDecodeTest(
1 -> 10L,
"tuple with 2 flat encoders")(
- ExpressionEncoder.tuple(ExpressionEncoder[Int](),
ExpressionEncoder[Long]()))
+ encoderFor(Encoders.tuple(Encoders.scalaInt, Encoders.scalaLong)))
encodeDecodeTest(
(PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
"tuple with 2 product encoders")(
- ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData](),
ExpressionEncoder[(Int, Long)]()))
+ encoderFor(Encoders.tuple(Encoders.product[PrimitiveData],
Encoders.product[(Int, Long)])))
encodeDecodeTest(
(PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
"tuple with flat encoder and product encoder")(
- ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData](),
ExpressionEncoder[Int]()))
+ encoderFor(Encoders.tuple(Encoders.product[PrimitiveData],
Encoders.scalaInt)))
encodeDecodeTest(
(3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
"tuple with product encoder and flat encoder")(
- ExpressionEncoder.tuple(ExpressionEncoder[Int](),
ExpressionEncoder[PrimitiveData]()))
+ encoderFor(Encoders.tuple(Encoders.scalaInt,
Encoders.product[PrimitiveData])))
encodeDecodeTest(
(1, (10, 100L)),
"nested tuple encoder") {
- val intEnc = ExpressionEncoder[Int]()
- val longEnc = ExpressionEncoder[Long]()
- ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
+ val intEnc = Encoders.scalaInt
+ val longEnc = Encoders.scalaLong
+ encoderFor(Encoders.tuple(intEnc, Encoders.tuple(intEnc, longEnc)))
}
// test for value classes
@@ -468,9 +468,8 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
// test for tupled encoders
{
- val schema = ExpressionEncoder.tuple(
- ExpressionEncoder[Int](),
- ExpressionEncoder[(String, Int)]()).schema
+ val encoder = encoderFor(Encoders.tuple(Encoders.scalaInt,
Encoders.product[(String, Int)]))
+ val schema = encoder.schema
assert(schema(0).nullable === false)
assert(schema(1).nullable)
assert(schema(1).dataType.asInstanceOf[StructType](0).nullable)
@@ -513,11 +512,11 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
}
test("throw exception for tuples with more than 22 elements") {
- val encoders = (0 to 22).map(_ =>
Encoders.scalaInt.asInstanceOf[ExpressionEncoder[_]])
+ val encoders = (0 to 22).map(_ => Encoders.scalaInt)
checkError(
exception = intercept[SparkUnsupportedOperationException] {
- ExpressionEncoder.tuple(encoders)
+ Encoders.tupleEncoder(encoders: _*)
},
condition = "_LEGACY_ERROR_TEMP_2150",
parameters = Map.empty)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index bb6d52308c19..33c9edb1cd21 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{Dataset, Encoders,
ForeachWriter, Observation, Rela
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery,
PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute,
UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue,
UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar,
UnresolvedTranspose}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder,
ExpressionEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -2318,16 +2318,17 @@ class SparkConnectPlanner(
if (fun.getArgumentsCount != 1) {
throw InvalidPlanInput("reduce requires single child expression")
}
- val udf = fun.getArgumentsList.asScala.map(transformExpression) match {
- case collection.Seq(f: ScalaUDF) =>
- f
+ val udf = fun.getArgumentsList.asScala match {
+ case collection.Seq(e)
+ if e.hasCommonInlineUserDefinedFunction &&
+ e.getCommonInlineUserDefinedFunction.hasScalarScalaUdf =>
+ unpackUdf(e.getCommonInlineUserDefinedFunction)
case other =>
throw InvalidPlanInput(s"reduce should carry a scalar scala udf, but
got $other")
}
- assert(udf.outputEncoder.isDefined)
- val tEncoder = udf.outputEncoder.get // (T, T) => T
- val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr
- TypedAggUtils.withInputType(reduce, tEncoder, dataAttributes)
+ val encoder = udf.outputEncoder
+ val reduce = ReduceAggregator(udf.function)(encoder).toColumn.expr
+ TypedAggUtils.withInputType(reduce, encoderFor(encoder), dataAttributes)
}
private def transformExpressionWithTypedReduceExpression(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6e5dcc24e29d..c147b6a56e02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters,
InternalRow, Query
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
+import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor,
ProductEncoder, StructEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
@@ -78,13 +79,14 @@ private[sql] object Dataset {
val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan):
Dataset[T] = {
- val dataset = new Dataset(sparkSession, logicalPlan,
implicitly[Encoder[T]])
+ val encoder = implicitly[Encoder[T]]
+ val dataset = new Dataset(sparkSession, logicalPlan, encoder)
// Eagerly bind the encoder so we verify that the encoder matches the
underlying
// schema. The user will get an error if this is not the case.
// optimization: it is guaranteed that [[InternalRow]] can be converted to
[[Row]] so
// do not do this check in that case. this check can be expensive since it
requires running
// the whole [[Analyzer]] to resolve the deserializer
- if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) {
+ if (dataset.encoder.clsTag.runtimeClass != classOf[Row]) {
dataset.resolvedEnc
}
dataset
@@ -94,7 +96,7 @@ private[sql] object Dataset {
sparkSession.withActive {
val qe = sparkSession.sessionState.executePlan(logicalPlan)
qe.assertAnalyzed()
- new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+ new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
}
def ofRows(
@@ -105,7 +107,7 @@ private[sql] object Dataset {
val qe = new QueryExecution(
sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
- new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+ new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
}
/** A variant of ofRows that allows passing in a tracker so we can track
query parsing time. */
@@ -118,7 +120,7 @@ private[sql] object Dataset {
val qe = new QueryExecution(
sparkSession, logicalPlan, tracker, shuffleCleanupMode =
shuffleCleanupMode)
qe.assertAnalyzed()
- new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+ new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
}
}
@@ -252,12 +254,17 @@ class Dataset[T] private[sql](
}
/**
- * Currently [[ExpressionEncoder]] is the only implementation of
[[Encoder]], here we turn the
- * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it
implicit so that we can use
- * it when constructing new Dataset objects that have the same object type
(that will be
- * possibly resolved to a different schema).
+ * Expose the encoder as implicit so it can be used to construct new Dataset
objects that have
+ * the same external type.
*/
- private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
+ private implicit def encoderImpl: Encoder[T] = encoder
+
+ /**
+ * The actual [[ExpressionEncoder]] used by the dataset. This and its
resolved counterpart should
+ * only be used for actual (de)serialization, the binding of Aggregator
inputs, and in the rare
+ * cases where a plan needs to be constructed with an ExpressionEncoder.
+ */
+ private[sql] lazy val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
// The resolved `ExpressionEncoder` which can be used to turn rows to
objects of type T, after
// collecting rows to the driver side.
@@ -265,7 +272,7 @@ class Dataset[T] private[sql](
exprEnc.resolveAndBind(logicalPlan.output,
sparkSession.sessionState.analyzer)
}
- private implicit def classTag: ClassTag[T] = exprEnc.clsTag
+ private implicit def classTag: ClassTag[T] = encoder.clsTag
// sqlContext must be val because a stable identifier is expected when you
import implicits
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
@@ -476,7 +483,7 @@ class Dataset[T] private[sql](
/** @inheritdoc */
// This is declared with parentheses to prevent the Scala compiler from
treating
// `ds.toDF("1")` as invoking this toDF and then apply on the returned
DataFrame.
- def toDF(): DataFrame = new Dataset[Row](queryExecution,
ExpressionEncoder(schema))
+ def toDF(): DataFrame = new Dataset[Row](queryExecution,
RowEncoder.encoderFor(schema))
/** @inheritdoc */
def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
@@ -671,17 +678,17 @@ class Dataset[T] private[sql](
Some(condition.expr),
JoinHint.NONE)).analyzed.asInstanceOf[Join]
- implicit val tuple2Encoder: Encoder[(T, U)] =
- ExpressionEncoder
- .tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer =
true)
- .asInstanceOf[Encoder[(T, U)]]
-
- withTypedPlan(JoinWith.typedJoinWith(
+ val leftEncoder = agnosticEncoderFor(encoder)
+ val rightEncoder = agnosticEncoderFor(other.encoder)
+ val joinEncoder = ProductEncoder.tuple(Seq(leftEncoder, rightEncoder),
elementsCanBeNull = true)
+ .asInstanceOf[Encoder[(T, U)]]
+ val joinWith = JoinWith.typedJoinWith(
joined,
sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity,
sparkSession.sessionState.analyzer.resolver,
- this.exprEnc.isSerializedAsStructForTopLevel,
- other.exprEnc.isSerializedAsStructForTopLevel))
+ leftEncoder.isStruct,
+ rightEncoder.isStruct)
+ new Dataset(sparkSession, joinWith, joinEncoder)
}
// TODO(SPARK-22947): Fix the DataFrame API.
@@ -826,24 +833,29 @@ class Dataset[T] private[sql](
/** @inheritdoc */
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
- implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder)
+ val encoder = agnosticEncoderFor(c1.encoder)
val tc1 = withInputType(c1.named, exprEnc, logicalPlan.output)
val project = Project(tc1 :: Nil, logicalPlan)
- if (!encoder.isSerializedAsStructForTopLevel) {
- new Dataset[U1](sparkSession, project, encoder)
- } else {
- // Flattens inner fields of U1
- new Dataset[Tuple1[U1]](sparkSession, project,
ExpressionEncoder.tuple(encoder)).map(_._1)
+ val plan = encoder match {
+ case se: StructEncoder[U1] =>
+ // Flatten the result.
+ val attribute = GetColumnByOrdinal(0, se.dataType)
+ val projectList = se.fields.zipWithIndex.map {
+ case (field, index) =>
+ Alias(GetStructField(attribute, index, None), field.name)()
+ }
+ Project(projectList, project)
+ case _ => project
}
+ new Dataset[U1](sparkSession, plan, encoder)
}
/** @inheritdoc */
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoders = columns.map(c => encoderFor(c.encoder))
+ val encoders = columns.map(c => agnosticEncoderFor(c.encoder))
val namedColumns = columns.map(c => withInputType(c.named, exprEnc,
logicalPlan.output))
- val execution = new QueryExecution(sparkSession, Project(namedColumns,
logicalPlan))
- new Dataset(execution, ExpressionEncoder.tuple(encoders))
+ new Dataset(sparkSession, Project(namedColumns, logicalPlan),
ProductEncoder.tuple(encoders))
}
/** @inheritdoc */
@@ -912,8 +924,8 @@ class Dataset[T] private[sql](
val executed = sparkSession.sessionState.executePlan(withGroupingKey)
new KeyValueGroupedDataset(
- encoderFor[K],
- encoderFor[T],
+ implicitly[Encoder[K]],
+ encoder,
executed,
logicalPlan.output,
withGroupingKey.newColumns)
@@ -1387,7 +1399,11 @@ class Dataset[T] private[sql](
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType): DataFrame = {
- val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+ val rowEncoder: ExpressionEncoder[Row] = if (isUnTyped) {
+ exprEnc.asInstanceOf[ExpressionEncoder[Row]]
+ } else {
+ ExpressionEncoder(schema)
+ }
Dataset.ofRows(
sparkSession,
MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder,
logicalPlan))
@@ -2237,7 +2253,7 @@ class Dataset[T] private[sql](
/** A convenient function to wrap a set based logical plan and produce a
Dataset. */
@inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan):
Dataset[U] = {
- if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
+ if (isUnTyped) {
// Set operators widen types (change the schema), so we cannot reuse the
row encoder.
Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
} else {
@@ -2245,6 +2261,8 @@ class Dataset[T] private[sql](
}
}
+ private def isUnTyped: Boolean =
classTag.runtimeClass.isAssignableFrom(classOf[Row])
+
/** Returns a optimized plan for CommandResult, convert to `LocalRelation`.
*/
private def commandResultOptimized: Dataset[T] = {
logicalPlan match {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 1ebdd57f1962..fcad1b721eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark,
UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor,
ProductEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -43,9 +44,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
extends api.KeyValueGroupedDataset[K, V, Dataset] {
type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
- // Similar to [[Dataset]], we turn the passed in encoder to
`ExpressionEncoder` explicitly.
- private implicit val kExprEnc: ExpressionEncoder[K] = encoderFor(kEncoder)
- private implicit val vExprEnc: ExpressionEncoder[V] = encoderFor(vEncoder)
+ private implicit def kEncoderImpl: Encoder[K] = kEncoder
+ private implicit def vEncoderImpl: Encoder[V] = vEncoder
private def logicalPlan = queryExecution.analyzed
private def sparkSession = queryExecution.sparkSession
@@ -54,8 +54,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
/** @inheritdoc */
def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
new KeyValueGroupedDataset(
- encoderFor[L],
- vExprEnc,
+ implicitly[Encoder[L]],
+ vEncoder,
queryExecution,
dataAttributes,
groupingAttributes)
@@ -67,8 +67,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
val executed = sparkSession.sessionState.executePlan(projected)
new KeyValueGroupedDataset(
- encoderFor[K],
- encoderFor[W],
+ kEncoder,
+ implicitly[Encoder[W]],
executed,
withNewData.newColumns,
groupingAttributes)
@@ -297,20 +297,21 @@ class KeyValueGroupedDataset[K, V] private[sql](
/** @inheritdoc */
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
- val vEncoder = encoderFor[V]
val aggregator: TypedColumn[V, V] = new
ReduceAggregator[V](f)(vEncoder).toColumn
agg(aggregator)
}
/** @inheritdoc */
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoders = columns.map(c => encoderFor(c.encoder))
- val namedColumns = columns.map(c => withInputType(c.named, vExprEnc,
dataAttributes))
- val keyColumn = aggKeyColumn(kExprEnc, groupingAttributes)
+ val keyAgEncoder = agnosticEncoderFor(kEncoder)
+ val valueExprEncoder = encoderFor(vEncoder)
+ val encoders = columns.map(c => agnosticEncoderFor(c.encoder))
+ val namedColumns = columns.map { c =>
+ withInputType(c.named, valueExprEncoder, dataAttributes)
+ }
+ val keyColumn = aggKeyColumn(keyAgEncoder, groupingAttributes)
val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns,
logicalPlan)
- val execution = new QueryExecution(sparkSession, aggregate)
-
- new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders))
+ new Dataset(sparkSession, aggregate, ProductEncoder.tuple(keyAgEncoder +:
encoders))
}
/** @inheritdoc */
@@ -319,7 +320,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
thisSortExprs: Column*)(
otherSortExprs: Column*)(
f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
- implicit val uEncoder = other.vExprEnc
+ implicit val uEncoder = other.vEncoderImpl
Dataset[R](
sparkSession,
CoGroup(
@@ -336,10 +337,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
override def toString: String = {
val builder = new StringBuilder
- val kFields = kExprEnc.schema.map { f =>
+ val kFields = kEncoder.schema.map { f =>
s"${f.name}: ${f.dataType.simpleString(2)}"
}
- val vFields = vExprEnc.schema.map { f =>
+ val vFields = vEncoder.schema.map { f =>
s"${f.name}: ${f.dataType.simpleString(2)}"
}
builder.append("KeyValueGroupedDataset: [key: [")
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 4e4454018e81..da4609135fd6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -22,7 +22,6 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -119,15 +118,12 @@ class RelationalGroupedDataset protected[sql](
/** @inheritdoc */
def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
- val keyEncoder = encoderFor[K]
- val valueEncoder = encoderFor[T]
-
val (qe, groupingAttributes) =
handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs)
new KeyValueGroupedDataset(
- keyEncoder,
- valueEncoder,
+ implicitly[Encoder[K]],
+ implicitly[Encoder[T]],
qe,
df.logicalPlan.output,
groupingAttributes)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 3832d7304407..09d9915022a6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -504,7 +504,7 @@ case class ScalaAggregator[IN, BUF, OUT](
private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
private[this] lazy val bufferDeserializer =
bufferEncoder.createDeserializer()
- private[this] lazy val outputEncoder =
agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
+ private[this] lazy val outputEncoder = encoderFor(agg.outputEncoder)
private[this] lazy val outputSerializer = outputEncoder.createSerializer()
def dataType: DataType = outputEncoder.objSerializer.dataType
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index 420c3e3be16d..273ffa6aefb7 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -32,8 +32,9 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{HOST, PORT}
import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.connector.read.InputPartition
import
org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader,
ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
@@ -57,8 +58,7 @@ class TextSocketContinuousStream(
implicit val defaultFormats: DefaultFormats = DefaultFormats
- private val encoder = ExpressionEncoder.tuple(ExpressionEncoder[String](),
- ExpressionEncoder[Timestamp]())
+ private val encoder = encoderFor(Encoders.tuple(Encoders.STRING,
Encoders.TIMESTAMP))
@GuardedBy("this")
private var socket: Socket = _
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index fd3df372a2d5..192b5bf65c4c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder,
ProductEncoder}
/**
* An aggregator that uses a single associative and commutative reduce
function. This reduce
@@ -46,10 +47,10 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T,
T) => T)
override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T])
- override def bufferEncoder: Encoder[(Boolean, T)] =
- ExpressionEncoder.tuple(
- ExpressionEncoder[Boolean](),
- encoder.asInstanceOf[ExpressionEncoder[T]])
+ override def bufferEncoder: Encoder[(Boolean, T)] = {
+ ProductEncoder.tuple(Seq(PrimitiveBooleanEncoder,
encoder.asInstanceOf[AgnosticEncoder[T]]))
+ .asInstanceOf[Encoder[(Boolean, T)]]
+ }
override def outputEncoder: Encoder[T] = encoder
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
index b6340a35e770..23ceb8135fa8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.internal
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
@@ -25,10 +27,10 @@ import
org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
private[sql] object TypedAggUtils {
def aggKeyColumn[A](
- encoder: ExpressionEncoder[A],
+ encoder: Encoder[A],
groupingAttributes: Seq[Attribute]): NamedExpression = {
- if (!encoder.isSerializedAsStructForTopLevel) {
- assert(groupingAttributes.length == 1)
+ val agnosticEncoder = agnosticEncoderFor(encoder)
+ if (!agnosticEncoder.isStruct) {
if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
groupingAttributes.head
} else {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
index ea6fd8ab312c..2456999b4382 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.StreamTest
@@ -201,7 +201,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends
StreamTest {
private def newStateManager[T: Encoder](version: Int, withTimestamp:
Boolean): StateManager = {
FlatMapGroupsWithStateExecHelper.createStateManager(
- implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]],
+ encoderFor[T].asInstanceOf[ExpressionEncoder[Any]],
withTimestamp,
version)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index add12f7e1535..e9300464af8d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -22,7 +22,7 @@ import java.util.UUID
import org.apache.spark.{SparkIllegalArgumentException,
SparkUnsupportedOperationException}
import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
ListStateImplWithTTL, StatefulProcessorHandleImpl}
import org.apache.spark.sql.streaming.{ListState, TimeMode, TTLConfig,
ValueState}
@@ -38,7 +38,7 @@ class ListStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val listState: ListState[Long] = handle.getListState[Long]("listState",
Encoders.scalaLong)
@@ -71,7 +71,7 @@ class ListStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: ListState[Long] = handle.getListState[Long]("testState",
Encoders.scalaLong)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -99,7 +99,7 @@ class ListStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState1: ListState[Long] =
handle.getListState[Long]("testState1", Encoders.scalaLong)
val testState2: ListState[Long] =
handle.getListState[Long]("testState2", Encoders.scalaLong)
@@ -137,7 +137,7 @@ class ListStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val listState1: ListState[Long] =
handle.getListState[Long]("listState1", Encoders.scalaLong)
val listState2: ListState[Long] =
handle.getListState[Long]("listState2", Encoders.scalaLong)
@@ -167,7 +167,7 @@ class ListStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val timestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
@@ -187,7 +187,7 @@ class ListStateSuite extends StateVariableSuiteBase {
// increment batchProcessingTime, or watermark and ensure expired value
is not returned
val nextBatchHandle = new StatefulProcessorHandleImpl(store,
UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
val nextBatchTestState: ListStateImplWithTTL[String] =
@@ -223,7 +223,7 @@ class ListStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val batchTimestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
@@ -250,7 +250,7 @@ class ListStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val timestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.bean(classOf[POJOTestClass]).asInstanceOf[ExpressionEncoder[Any]],
+ encoderFor(Encoders.bean(classOf[POJOTestClass])).asInstanceOf[ExpressionEncoder[Any]],
TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index 9c322b201da8..b067d589de90 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -22,7 +22,6 @@ import java.util.UUID
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
MapStateImplWithTTL, StatefulProcessorHandleImpl}
import org.apache.spark.sql.streaming.{ListState, MapState, TimeMode,
TTLConfig, ValueState}
import org.apache.spark.sql.types.{BinaryType, StructType}
@@ -41,7 +40,7 @@ class MapStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: MapState[String, Double] =
handle.getMapState[String, Double]("testState", Encoders.STRING,
Encoders.scalaDouble)
@@ -75,7 +74,7 @@ class MapStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState1: MapState[Long, Double] =
handle.getMapState[Long, Double]("testState1", Encoders.scalaLong,
Encoders.scalaDouble)
@@ -114,7 +113,7 @@ class MapStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val mapTestState1: MapState[String, Int] =
handle.getMapState[String, Int]("mapTestState1", Encoders.STRING,
Encoders.scalaInt)
@@ -175,7 +174,7 @@ class MapStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val timestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
+ stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(timestampMs))
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
@@ -196,7 +195,7 @@ class MapStateSuite extends StateVariableSuiteBase {
// increment batchProcessingTime, or watermark and ensure expired value is not returned
val nextBatchHandle = new StatefulProcessorHandleImpl(store,
UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
val nextBatchTestState: MapStateImplWithTTL[String, String] =
@@ -233,7 +232,7 @@ class MapStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val batchTimestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
@@ -261,7 +260,7 @@ class MapStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val timestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
+ stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(timestampMs))
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index e2940497e911..48a6fd836a46 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -22,7 +22,6 @@ import java.util.UUID
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
import org.apache.spark.sql.streaming.{TimeMode, TTLConfig}
@@ -33,9 +32,6 @@ import org.apache.spark.sql.streaming.{TimeMode, TTLConfig}
*/
class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
- private def keyExprEncoder: ExpressionEncoder[Any] =
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]
-
private def getTimeMode(timeMode: String): TimeMode = {
timeMode match {
case "None" => TimeMode.None()
@@ -50,7 +46,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+ UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
assert(handle.getHandleState === StatefulProcessorHandleState.CREATED)
handle.getValueState[Long]("testState", Encoders.scalaLong)
}
@@ -91,7 +87,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+ UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
Seq(StatefulProcessorHandleState.INITIALIZED,
StatefulProcessorHandleState.DATA_PROCESSED,
@@ -109,7 +105,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, TimeMode.None())
+ UUID.randomUUID(), stringEncoder, TimeMode.None())
val ex = intercept[SparkUnsupportedOperationException] {
handle.registerTimer(10000L)
}
@@ -145,7 +141,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+ UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
handle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
assert(handle.getHandleState ===
StatefulProcessorHandleState.INITIALIZED)
@@ -166,7 +162,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+ UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
assert(handle.getHandleState ===
StatefulProcessorHandleState.DATA_PROCESSED)
@@ -206,7 +202,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+ UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
Seq(StatefulProcessorHandleState.CREATED,
StatefulProcessorHandleState.TIMER_PROCESSED,
@@ -223,7 +219,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+ UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(10))
val valueStateWithTTL = handle.getValueState("testState",
@@ -241,7 +237,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+ UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(10))
val listStateWithTTL = handle.getListState("testState",
@@ -259,7 +255,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+ UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(10))
val mapStateWithTTL = handle.getMapState("testState",
@@ -277,7 +273,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), keyExprEncoder, TimeMode.None())
+ UUID.randomUUID(), stringEncoder, TimeMode.None())
handle.getValueState("testValueState", Encoders.STRING)
handle.getListState("testListState", Encoders.STRING)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index df6a3fd7b23e..24a120be9d9a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.streaming.state
import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
TimerStateImpl}
import org.apache.spark.sql.streaming.TimeMode
@@ -45,7 +45,7 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
val timerState = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
timerState.registerTimer(1L * 1000)
assert(timerState.listTimers().toSet === Set(1000L))
assert(timerState.getExpiredTimers(Long.MaxValue).toSeq ===
Seq(("test_key", 1000L)))
@@ -64,9 +64,9 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
val timerState1 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
val timerState2 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
timerState1.registerTimer(1L * 1000)
timerState2.registerTimer(15L * 1000)
assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
@@ -89,7 +89,7 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
val timerState1 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
timerState1.registerTimer(1L * 1000)
timerState1.registerTimer(2L * 1000)
assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
@@ -97,7 +97,7 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
val timerState2 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
timerState2.registerTimer(15L * 1000)
ImplicitGroupingKeyTracker.removeImplicitKey()
@@ -122,7 +122,7 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
val timerState = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
val timerTimerstamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L,
3L, 35L, 6L, 9L, 5L)
// register/put unordered timestamp into rocksDB
timerTimerstamps.foreach(timerState.registerTimer)
@@ -141,19 +141,19 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
val timerState1 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
val timerTimestamps1 = Seq(64L, 32L, 1024L, 4096L, 0L, 1L)
timerTimestamps1.foreach(timerState1.registerTimer)
val timerState2 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
val timerTimestamps2 = Seq(931L, 8000L, 452300L, 4200L)
timerTimestamps2.foreach(timerState2.registerTimer)
ImplicitGroupingKeyTracker.removeImplicitKey()
ImplicitGroupingKeyTracker.setImplicitKey("test_key3")
val timerState3 = new TimerStateImpl(store, timeMode,
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ stringEncoder)
val timerTimerStamps3 = Seq(1L, 2L, 8L, 3L)
timerTimerStamps3.foreach(timerState3.registerTimer)
ImplicitGroupingKeyTracker.removeImplicitKey()
@@ -171,7 +171,7 @@ class TimerSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
ImplicitGroupingKeyTracker.setImplicitKey(TestClass(1L, "k1"))
val timerState = new TimerStateImpl(store, timeMode,
- Encoders.product[TestClass].asInstanceOf[ExpressionEncoder[Any]])
+ encoderFor(Encoders.product[TestClass]).asInstanceOf[ExpressionEncoder[Any]])
timerState.registerTimer(1L * 1000)
assert(timerState.listTimers().toSet === Set(1000L))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 41912a4dda23..13d758eb1b88 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
StatefulProcessorHandleImpl, ValueStateImplWithTTL}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{TimeMode, TTLConfig, ValueState}
@@ -49,7 +49,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val stateName = "testState"
val testState: ValueState[Long] =
handle.getValueState[Long]("testState", Encoders.scalaLong)
@@ -93,7 +93,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: ValueState[Long] =
handle.getValueState[Long]("testState", Encoders.scalaLong)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -119,7 +119,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState1: ValueState[Long] = handle.getValueState[Long](
"testState1", Encoders.scalaLong)
@@ -164,7 +164,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
- UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ UUID.randomUUID(), stringEncoder, TimeMode.None())
val cfName = "$testState"
val ex = intercept[SparkUnsupportedOperationException] {
@@ -204,7 +204,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: ValueState[Double] =
handle.getValueState[Double]("testState",
Encoders.scalaDouble)
@@ -230,7 +230,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: ValueState[Long] = handle.getValueState[Long]("testState",
Encoders.scalaLong)
@@ -256,7 +256,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: ValueState[TestClass] =
handle.getValueState[TestClass]("testState",
Encoders.product[TestClass])
@@ -282,7 +282,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+ stringEncoder, TimeMode.None())
val testState: ValueState[POJOTestClass] =
handle.getValueState[POJOTestClass]("testState",
Encoders.bean(classOf[POJOTestClass]))
@@ -310,7 +310,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val timestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
+ stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(timestampMs))
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
@@ -330,7 +330,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
// increment batchProcessingTime, or watermark and ensure expired value is not returned
val nextBatchHandle = new StatefulProcessorHandleImpl(store,
UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
val nextBatchTestState: ValueStateImplWithTTL[String] =
@@ -366,7 +366,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val batchTimestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+ stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
@@ -393,8 +393,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
val store = provider.getStore(0)
val timestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
- Encoders.product[TestClass].asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
- batchTimestampMs = Some(timestampMs))
+ encoderFor(Encoders.product[TestClass]).asInstanceOf[ExpressionEncoder[Any]],
+ TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
val testState: ValueStateImplWithTTL[POJOTestClass] =
@@ -437,6 +437,8 @@ abstract class StateVariableSuiteBase extends
SharedSparkSession
import StateStoreTestsHelper._
+ protected val stringEncoder = encoderFor(Encoders.STRING).asInstanceOf[ExpressionEncoder[Any]]
+
// dummy schema for initializing rocksdb provider
protected def schemaForKeyRow: StructType = new StructType().add("key",
BinaryType)
protected def schemaForValueRow: StructType = new StructType().add("value",
BinaryType)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]