This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 31cd6991558c [SPARK-49424][CONNECT][SQL] Consolidate Encoders.scala
31cd6991558c is described below

commit 31cd6991558c45cea56ba25cb89f13e64e3d93fa
Author: Herman van Hovell <[email protected]>
AuthorDate: Tue Sep 17 23:00:49 2024 -0400

    [SPARK-49424][CONNECT][SQL] Consolidate Encoders.scala
    
    ### What changes were proposed in this pull request?
    This PR consolidates the two copies of Encoders.scala (the one in sql/catalyst
    and the one in the Connect JVM client) into a single file in sql/api.
    
    ### Why are the changes needed?
    We are creating a unified Scala interface for Classic and Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48021 from hvanhovell/SPARK-49424.
    
    Authored-by: Herman van Hovell <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   6 +
 project/MimaExcludes.scala                         |   4 +
 .../main/scala/org/apache/spark/sql/Encoders.scala | 152 +++++----
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  13 +
 .../apache/spark/sql/errors/ExecutionErrors.scala  |  18 ++
 .../main/scala/org/apache/spark/sql/Encoders.scala | 348 ---------------------
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |  51 +--
 .../spark/sql/errors/QueryExecutionErrors.scala    |  11 -
 .../catalyst/encoders/EncoderResolutionSuite.scala |   8 +-
 .../catalyst/encoders/ExpressionEncoderSuite.scala |  23 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  17 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  84 +++--
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  37 +--
 .../spark/sql/RelationalGroupedDataset.scala       |   8 +-
 .../spark/sql/execution/aggregate/udaf.scala       |   2 +-
 .../continuous/ContinuousTextSocketSource.scala    |   6 +-
 .../spark/sql/expressions/ReduceAggregator.scala   |  11 +-
 .../apache/spark/sql/internal/TypedAggUtils.scala  |   8 +-
 .../FlatMapGroupsWithStateExecHelperSuite.scala    |   4 +-
 .../execution/streaming/state/ListStateSuite.scala |  18 +-
 .../execution/streaming/state/MapStateSuite.scala  |  15 +-
 .../state/StatefulProcessorHandleSuite.scala       |  24 +-
 .../sql/execution/streaming/state/TimerSuite.scala |  22 +-
 .../streaming/state/ValueStateSuite.scala          |  30 +-
 24 files changed, 300 insertions(+), 620 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 57b3d33741e9..25dd676c4aff 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1996,6 +1996,12 @@
     },
     "sqlState" : "42903"
   },
+  "INVALID_AGNOSTIC_ENCODER" : {
+    "message" : [
+      "Found an invalid agnostic encoder. Expects an instance of 
AgnosticEncoder but got <encoderType>. For more information consult 
'<docroot>/api/java/index.html?org/apache/spark/sql/Encoder.html'."
+    ],
+    "sqlState" : "42001"
+  },
   "INVALID_ARRAY_INDEX" : {
     "message" : [
       "The index <indexValue> is out of bounds. The array has <arraySize> 
elements. Use the SQL function `get()` to tolerate accessing element at invalid 
index and return NULL instead. If necessary set <ansiConfig> to \"false\" to 
bypass this error."
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 68433b501bcc..dfe7b14e2ec6 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -160,6 +160,10 @@ object MimaExcludes {
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriterV2"),
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.WriteConfigMethods"),
 
+    // SPARK-49424: Shared Encoders
+    
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
+    
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"),
+
     // SPARK-49413: Create a shared RuntimeConfig interface.
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig"),
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"),
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
 b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
similarity index 77%
rename from 
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index 33a322109c1b..9976b34f7a01 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -14,95 +14,99 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.spark.sql
 
-import scala.reflect.ClassTag
+import java.lang.reflect.Modifier
+
+import scala.reflect.{classTag, ClassTag}
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
JavaSerializationCodec, KryoSerializationCodec, RowEncoder => RowEncoderFactory}
+import org.apache.spark.sql.catalyst.encoders.{Codec, JavaSerializationCodec, 
KryoSerializationCodec, RowEncoder => SchemaInference}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.types._
 
 /**
  * Methods for creating an [[Encoder]].
  *
- * @since 3.5.0
+ * @since 1.6.0
  */
 object Encoders {
 
   /**
    * An encoder for nullable boolean type. The Scala primitive encoder is 
available as
    * [[scalaBoolean]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
 
   /**
    * An encoder for nullable byte type. The Scala primitive encoder is 
available as [[scalaByte]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder
 
   /**
    * An encoder for nullable short type. The Scala primitive encoder is 
available as
    * [[scalaShort]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder
 
   /**
    * An encoder for nullable int type. The Scala primitive encoder is 
available as [[scalaInt]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def INT: Encoder[java.lang.Integer] = BoxedIntEncoder
 
   /**
    * An encoder for nullable long type. The Scala primitive encoder is 
available as [[scalaLong]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def LONG: Encoder[java.lang.Long] = BoxedLongEncoder
 
   /**
    * An encoder for nullable float type. The Scala primitive encoder is 
available as
    * [[scalaFloat]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder
 
   /**
    * An encoder for nullable double type. The Scala primitive encoder is 
available as
    * [[scalaDouble]].
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder
 
   /**
    * An encoder for nullable string type.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def STRING: Encoder[java.lang.String] = StringEncoder
 
   /**
    * An encoder for nullable decimal type.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER
 
   /**
    * An encoder for nullable date type.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
-  def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false)
+  def DATE: Encoder[java.sql.Date] = STRICT_DATE_ENCODER
 
   /**
    * Creates an encoder that serializes instances of the `java.time.LocalDate` 
class to the
    * internal representation of nullable Catalyst's DateType.
    *
-   * @since 3.5.0
+   * @since 3.0.0
    */
   def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER
 
@@ -110,14 +114,14 @@ object Encoders {
    * Creates an encoder that serializes instances of the 
`java.time.LocalDateTime` class to the
    * internal representation of nullable Catalyst's TimestampNTZType.
    *
-   * @since 3.5.0
+   * @since 3.4.0
    */
   def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder
 
   /**
    * An encoder for nullable timestamp type.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER
 
@@ -125,14 +129,14 @@ object Encoders {
    * Creates an encoder that serializes instances of the `java.time.Instant` 
class to the internal
    * representation of nullable Catalyst's TimestampType.
    *
-   * @since 3.5.0
+   * @since 3.0.0
    */
   def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER
 
   /**
    * An encoder for arrays of bytes.
    *
-   * @since 3.5.0
+   * @since 1.6.1
    */
   def BINARY: Encoder[Array[Byte]] = BinaryEncoder
 
@@ -140,7 +144,7 @@ object Encoders {
    * Creates an encoder that serializes instances of the `java.time.Duration` 
class to the
    * internal representation of nullable Catalyst's DayTimeIntervalType.
    *
-   * @since 3.5.0
+   * @since 3.2.0
    */
   def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder
 
@@ -148,7 +152,7 @@ object Encoders {
    * Creates an encoder that serializes instances of the `java.time.Period` 
class to the internal
    * representation of nullable Catalyst's YearMonthIntervalType.
    *
-   * @since 3.5.0
+   * @since 3.2.0
    */
   def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder
 
@@ -166,7 +170,7 @@ object Encoders {
    *   - collection types: array, java.util.List, and map
    *   - nested java bean.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def bean[T](beanClass: Class[T]): Encoder[T] = 
JavaTypeInference.encoderFor(beanClass)
 
@@ -175,71 +179,96 @@ object Encoders {
    *
    * @since 3.5.0
    */
-  def row(schema: StructType): Encoder[Row] = 
RowEncoderFactory.encoderFor(schema)
+  def row(schema: StructType): Encoder[Row] = 
SchemaInference.encoderFor(schema)
 
   /**
-   * (Scala-specific) Creates an encoder that serializes objects of type T 
using generic Java
-   * serialization. This encoder maps T into a single byte array (binary) 
field.
+   * (Scala-specific) Creates an encoder that serializes objects of type T 
using Kryo. This
+   * encoder maps T into a single byte array (binary) field.
    *
    * T must be publicly accessible.
    *
-   * @note
-   *   This is extremely inefficient and should only be used as the last 
resort.
-   * @since 4.0.0
+   * @since 1.6.0
    */
-  def javaSerialization[T: ClassTag]: Encoder[T] = {
-    TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, 
JavaSerializationCodec)
-  }
+  def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)
 
   /**
-   * Creates an encoder that serializes objects of type T using generic Java 
serialization. This
-   * encoder maps T into a single byte array (binary) field.
+   * Creates an encoder that serializes objects of type T using Kryo. This 
encoder maps T into a
+   * single byte array (binary) field.
+   *
+   * T must be publicly accessible.
+   *
+   * @since 1.6.0
+   */
+  def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
+  /**
+   * (Scala-specific) Creates an encoder that serializes objects of type T 
using generic Java
+   * serialization. This encoder maps T into a single byte array (binary) 
field.
    *
    * T must be publicly accessible.
    *
    * @note
    *   This is extremely inefficient and should only be used as the last 
resort.
-   * @since 4.0.0
+   *
+   * @since 1.6.0
    */
-  def javaSerialization[T](clazz: Class[T]): Encoder[T] = 
javaSerialization(ClassTag[T](clazz))
+  def javaSerialization[T: ClassTag]: Encoder[T] = 
genericSerializer(JavaSerializationCodec)
 
   /**
-   * (Scala-specific) Creates an encoder that serializes objects of type T 
using Kryo. This
+   * Creates an encoder that serializes objects of type T using generic Java 
serialization. This
    * encoder maps T into a single byte array (binary) field.
    *
    * T must be publicly accessible.
    *
-   * @since 4.0.0
+   * @note
+   *   This is extremely inefficient and should only be used as the last 
resort.
+   *
+   * @since 1.6.0
    */
-  def kryo[T: ClassTag]: Encoder[T] = {
-    TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, 
KryoSerializationCodec)
+  def javaSerialization[T](clazz: Class[T]): Encoder[T] =
+    javaSerialization(ClassTag[T](clazz))
+
+  /** Throws an exception if T is not a public class. */
+  private def validatePublicClass[T: ClassTag](): Unit = {
+    if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
+      throw 
ExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName)
+    }
+  }
+
+  /** A way to construct encoders using generic serializers. */
+  private def genericSerializer[T: ClassTag](
+      provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
+    if (classTag[T].runtimeClass.isPrimitive) {
+      throw ExecutionErrors.primitiveTypesNotSupportedError()
+    }
+
+    validatePublicClass[T]()
+
+    TransformingEncoder(classTag[T], BinaryEncoder, provider)
+  }
+
+  private[sql] def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
+    
ProductEncoder.tuple(encoders.map(agnosticEncoderFor(_))).asInstanceOf[Encoder[T]]
   }
 
   /**
-   * Creates an encoder that serializes objects of type T using Kryo. This 
encoder maps T into a
-   * single byte array (binary) field.
-   *
-   * T must be publicly accessible.
+   * An encoder for 1-ary tuples.
    *
    * @since 4.0.0
    */
-  def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
-  private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
-    
ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
-  }
+  def tuple[T1](e1: Encoder[T1]): Encoder[(T1)] = tupleEncoder(e1)
 
   /**
    * An encoder for 2-ary tuples.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = 
tupleEncoder(e1, e2)
 
   /**
    * An encoder for 3-ary tuples.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def tuple[T1, T2, T3](
       e1: Encoder[T1],
@@ -249,7 +278,7 @@ object Encoders {
   /**
    * An encoder for 4-ary tuples.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def tuple[T1, T2, T3, T4](
       e1: Encoder[T1],
@@ -260,7 +289,7 @@ object Encoders {
   /**
    * An encoder for 5-ary tuples.
    *
-   * @since 3.5.0
+   * @since 1.6.0
    */
   def tuple[T1, T2, T3, T4, T5](
       e1: Encoder[T1],
@@ -271,49 +300,50 @@ object Encoders {
 
   /**
    * An encoder for Scala's product type (tuples, case classes, etc).
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def product[T <: Product: TypeTag]: Encoder[T] = 
ScalaReflection.encoderFor[T]
 
   /**
    * An encoder for Scala's primitive int type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaInt: Encoder[Int] = PrimitiveIntEncoder
 
   /**
    * An encoder for Scala's primitive long type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaLong: Encoder[Long] = PrimitiveLongEncoder
 
   /**
    * An encoder for Scala's primitive double type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder
 
   /**
    * An encoder for Scala's primitive float type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder
 
   /**
    * An encoder for Scala's primitive byte type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaByte: Encoder[Byte] = PrimitiveByteEncoder
 
   /**
    * An encoder for Scala's primitive short type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaShort: Encoder[Short] = PrimitiveShortEncoder
 
   /**
    * An encoder for Scala's primitive boolean type.
-   * @since 3.5.0
+   * @since 2.0.0
    */
   def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder
+
 }
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index a57849575549..10f734b3f84e 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -46,7 +46,20 @@ trait AgnosticEncoder[T] extends Encoder[T] {
   def isStruct: Boolean = false
 }
 
+/**
+ * Extract an [[AgnosticEncoder]] from an [[Encoder]].
+ */
+trait ToAgnosticEncoder[T] {
+  def encoder: AgnosticEncoder[T]
+}
+
 object AgnosticEncoders {
+  def agnosticEncoderFor[T: Encoder]: AgnosticEncoder[T] = 
implicitly[Encoder[T]] match {
+    case a: AgnosticEncoder[T] => a
+    case e: ToAgnosticEncoder[T @unchecked] => e.encoder
+    case other => throw ExecutionErrors.invalidAgnosticEncoderError(other)
+  }
+
   case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
       extends AgnosticEncoder[Option[E]] {
     override def isPrimitive: Boolean = false
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 4890ff4431fe..698a7b096e1a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -217,9 +217,27 @@ private[sql] trait ExecutionErrors extends 
DataTypeErrorsBase {
     new SparkRuntimeException(errorClass = "CANNOT_USE_KRYO", 
messageParameters = Map.empty)
   }
 
+  def notPublicClassError(name: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2229",
+      messageParameters = Map("name" -> name))
+  }
+
+  def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(errorClass = 
"_LEGACY_ERROR_TEMP_2230")
+  }
+
   def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
   }
+
+  def invalidAgnosticEncoderError(encoder: AnyRef): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "INVALID_AGNOSTIC_ENCODER",
+      messageParameters = Map(
+        "encoderType" -> encoder.getClass.getName,
+        "docroot" -> SparkBuildInfo.spark_doc_root))
+  }
 }
 
 private[sql] object ExecutionErrors extends ExecutionErrors
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
deleted file mode 100644
index 7e040f6232fb..000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ /dev/null
@@ -1,348 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.lang.reflect.Modifier
-
-import scala.reflect.{classTag, ClassTag}
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Codec, 
ExpressionEncoder, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, 
TransformingEncoder}
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types._
-
-/**
- * Methods for creating an [[Encoder]].
- *
- * @since 1.6.0
- */
-object Encoders {
-
-  /**
-   * An encoder for nullable boolean type.
-   * The Scala primitive encoder is available as [[scalaBoolean]].
-   * @since 1.6.0
-   */
-  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable byte type.
-   * The Scala primitive encoder is available as [[scalaByte]].
-   * @since 1.6.0
-   */
-  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable short type.
-   * The Scala primitive encoder is available as [[scalaShort]].
-   * @since 1.6.0
-   */
-  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable int type.
-   * The Scala primitive encoder is available as [[scalaInt]].
-   * @since 1.6.0
-   */
-  def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable long type.
-   * The Scala primitive encoder is available as [[scalaLong]].
-   * @since 1.6.0
-   */
-  def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable float type.
-   * The Scala primitive encoder is available as [[scalaFloat]].
-   * @since 1.6.0
-   */
-  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable double type.
-   * The Scala primitive encoder is available as [[scalaDouble]].
-   * @since 1.6.0
-   */
-  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable string type.
-   *
-   * @since 1.6.0
-   */
-  def STRING: Encoder[java.lang.String] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable decimal type.
-   *
-   * @since 1.6.0
-   */
-  def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable date type.
-   *
-   * @since 1.6.0
-   */
-  def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
-
-  /**
-   * Creates an encoder that serializes instances of the `java.time.LocalDate` 
class
-   * to the internal representation of nullable Catalyst's DateType.
-   *
-   * @since 3.0.0
-   */
-  def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder()
-
-  /**
-   * Creates an encoder that serializes instances of the 
`java.time.LocalDateTime` class
-   * to the internal representation of nullable Catalyst's TimestampNTZType.
-   *
-   * @since 3.4.0
-   */
-  def LOCALDATETIME: Encoder[java.time.LocalDateTime] = ExpressionEncoder()
-
-  /**
-   * An encoder for nullable timestamp type.
-   *
-   * @since 1.6.0
-   */
-  def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
-
-  /**
-   * Creates an encoder that serializes instances of the `java.time.Instant` 
class
-   * to the internal representation of nullable Catalyst's TimestampType.
-   *
-   * @since 3.0.0
-   */
-  def INSTANT: Encoder[java.time.Instant] = ExpressionEncoder()
-
-  /**
-   * An encoder for arrays of bytes.
-   *
-   * @since 1.6.1
-   */
-  def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
-
-  /**
-   * Creates an encoder that serializes instances of the `java.time.Duration` 
class
-   * to the internal representation of nullable Catalyst's DayTimeIntervalType.
-   *
-   * @since 3.2.0
-   */
-  def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
-
-  /**
-   * Creates an encoder that serializes instances of the `java.time.Period` 
class
-   * to the internal representation of nullable Catalyst's 
YearMonthIntervalType.
-   *
-   * @since 3.2.0
-   */
-  def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
-
-  /**
-   * Creates an encoder for Java Bean of type T.
-   *
-   * T must be publicly accessible.
-   *
-   * supported types for java bean field:
-   *  - primitive types: boolean, int, double, etc.
-   *  - boxed types: Boolean, Integer, Double, etc.
-   *  - String
-   *  - java.math.BigDecimal, java.math.BigInteger
-   *  - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, 
java.time.Instant
-   *  - collection types: array, java.util.List, and map
-   *  - nested java bean.
-   *
-   * @since 1.6.0
-   */
-  def bean[T](beanClass: Class[T]): Encoder[T] = 
ExpressionEncoder.javaBean(beanClass)
-
-  /**
-   * Creates a [[Row]] encoder for schema `schema`.
-   *
-   * @since 3.5.0
-   */
-  def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema)
-
-  /**
-   * (Scala-specific) Creates an encoder that serializes objects of type T 
using Kryo.
-   * This encoder maps T into a single byte array (binary) field.
-   *
-   * T must be publicly accessible.
-   *
-   * @since 1.6.0
-   */
-  def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)
-
-  /**
-   * Creates an encoder that serializes objects of type T using Kryo.
-   * This encoder maps T into a single byte array (binary) field.
-   *
-   * T must be publicly accessible.
-   *
-   * @since 1.6.0
-   */
-  def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
-  /**
-   * (Scala-specific) Creates an encoder that serializes objects of type T 
using generic Java
-   * serialization. This encoder maps T into a single byte array (binary) 
field.
-   *
-   * T must be publicly accessible.
-   *
-   * @note This is extremely inefficient and should only be used as the last 
resort.
-   *
-   * @since 1.6.0
-   */
-  def javaSerialization[T: ClassTag]: Encoder[T] = 
genericSerializer(JavaSerializationCodec)
-
-  /**
-   * Creates an encoder that serializes objects of type T using generic Java 
serialization.
-   * This encoder maps T into a single byte array (binary) field.
-   *
-   * T must be publicly accessible.
-   *
-   * @note This is extremely inefficient and should only be used as the last 
resort.
-   *
-   * @since 1.6.0
-   */
-  def javaSerialization[T](clazz: Class[T]): Encoder[T] = 
javaSerialization(ClassTag[T](clazz))
-
-  /** Throws an exception if T is not a public class. */
-  private def validatePublicClass[T: ClassTag](): Unit = {
-    if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
-      throw 
QueryExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName)
-    }
-  }
-
-  /** A way to construct encoders using generic serializers. */
-  private def genericSerializer[T: ClassTag](
-    provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
-    if (classTag[T].runtimeClass.isPrimitive) {
-      throw QueryExecutionErrors.primitiveTypesNotSupportedError()
-    }
-
-    validatePublicClass[T]()
-
-    ExpressionEncoder(TransformingEncoder(classTag[T], BinaryEncoder, 
provider))
-  }
-
-  /**
-   * An encoder for 2-ary tuples.
-   *
-   * @since 1.6.0
-   */
-  def tuple[T1, T2](
-    e1: Encoder[T1],
-    e2: Encoder[T2]): Encoder[(T1, T2)] = {
-    ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
-  }
-
-  /**
-   * An encoder for 3-ary tuples.
-   *
-   * @since 1.6.0
-   */
-  def tuple[T1, T2, T3](
-    e1: Encoder[T1],
-    e2: Encoder[T2],
-    e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
-    ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
-  }
-
-  /**
-   * An encoder for 4-ary tuples.
-   *
-   * @since 1.6.0
-   */
-  def tuple[T1, T2, T3, T4](
-    e1: Encoder[T1],
-    e2: Encoder[T2],
-    e3: Encoder[T3],
-    e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
-    ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), 
encoderFor(e4))
-  }
-
-  /**
-   * An encoder for 5-ary tuples.
-   *
-   * @since 1.6.0
-   */
-  def tuple[T1, T2, T3, T4, T5](
-    e1: Encoder[T1],
-    e2: Encoder[T2],
-    e3: Encoder[T3],
-    e4: Encoder[T4],
-    e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
-    ExpressionEncoder.tuple(
-      encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), 
encoderFor(e5))
-  }
-
-  /**
-   * An encoder for Scala's product type (tuples, case classes, etc).
-   * @since 2.0.0
-   */
-  def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive int type.
-   * @since 2.0.0
-   */
-  def scalaInt: Encoder[Int] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive long type.
-   * @since 2.0.0
-   */
-  def scalaLong: Encoder[Long] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive double type.
-   * @since 2.0.0
-   */
-  def scalaDouble: Encoder[Double] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive float type.
-   * @since 2.0.0
-   */
-  def scalaFloat: Encoder[Float] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive byte type.
-   * @since 2.0.0
-   */
-  def scalaByte: Encoder[Byte] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive short type.
-   * @since 2.0.0
-   */
-  def scalaShort: Encoder[Short] = ExpressionEncoder()
-
-  /**
-   * An encoder for Scala's primitive boolean type.
-   * @since 2.0.0
-   */
-  def scalaBoolean: Encoder[Boolean] = ExpressionEncoder()
-
-}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 8e39ae0389c2..d7d53230470d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -70,54 +70,6 @@ object ExpressionEncoder {
      apply(JavaTypeInference.encoderFor(beanClass))
   }
 
-  /**
-   * Given a set of N encoders, constructs a new encoder that produce objects 
as items in an
-   * N-tuple.  Note that these encoders should be unresolved so that 
information about
-   * name/positional binding is preserved.
-   * When `useNullSafeDeserializer` is true, the deserialization result for a 
child will be null if
-   * the input is null. It is false by default as most deserializers handle 
null input properly and
-   * don't require an extra null check. Some of them are null-tolerant, such 
as the deserializer for
-   * `Option[T]`, and we must not set it to true in this case.
-   */
-  def tuple(
-      encoders: Seq[ExpressionEncoder[_]],
-      useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
-    val tupleEncoder = AgnosticEncoders.ProductEncoder.tuple(
-      encoders.map(_.encoder),
-      useNullSafeDeserializer)
-    ExpressionEncoder(tupleEncoder)
-  }
-
-  // Tuple1
-  def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
-    tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
-
-  def tuple[T1, T2](
-      e1: ExpressionEncoder[T1],
-      e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
-    tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
-
-  def tuple[T1, T2, T3](
-      e1: ExpressionEncoder[T1],
-      e2: ExpressionEncoder[T2],
-      e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
-    tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
-
-  def tuple[T1, T2, T3, T4](
-      e1: ExpressionEncoder[T1],
-      e2: ExpressionEncoder[T2],
-      e3: ExpressionEncoder[T3],
-      e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
-    tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, 
T4)]]
-
-  def tuple[T1, T2, T3, T4, T5](
-      e1: ExpressionEncoder[T1],
-      e2: ExpressionEncoder[T2],
-      e3: ExpressionEncoder[T3],
-      e4: ExpressionEncoder[T4],
-      e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
-    tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, 
T4, T5)]]
-
   private val anyObjectType = ObjectType(classOf[Any])
 
   /**
@@ -189,7 +141,8 @@ case class ExpressionEncoder[T](
     encoder: AgnosticEncoder[T],
     objSerializer: Expression,
     objDeserializer: Expression)
-  extends Encoder[T] {
+  extends Encoder[T]
+  with ToAgnosticEncoder[T] {
 
   override def clsTag: ClassTag[T] = encoder.clsTag
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 2ab86a5c5f03..4bc071155012 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1876,17 +1876,6 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
       cause = null)
   }
 
-  def notPublicClassError(name: String): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "_LEGACY_ERROR_TEMP_2229",
-      messageParameters = Map(
-        "name" -> name))
-  }
-
-  def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(errorClass = 
"_LEGACY_ERROR_TEMP_2230")
-  }
-
   def onlySupportDataSourcesProvidingFileFormatError(providingClass: String): 
Throwable = {
     new SparkException(
       errorClass = "_LEGACY_ERROR_TEMP_2233",
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 28796db7c02e..35a27f41da80 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.SparkRuntimeException
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, Encoders}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference}
@@ -71,9 +71,9 @@ class EncoderResolutionSuite extends PlanTest {
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled 
encoder") {
-    val encoder = ExpressionEncoder.tuple(
-      ExpressionEncoder[StringLongClass](),
-      ExpressionEncoder[Long]())
+    val encoder = encoderFor(Encoders.tuple(
+      Encoders.product[StringLongClass],
+      Encoders.scalaLong))
     val attrs = Seq($"a".struct($"a".string, $"b".byte), $"b".int)
     testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 0c0c7f12f176..3b5cbed2cc52 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -321,29 +321,29 @@ class ExpressionEncoderSuite extends 
CodegenInterpretedPlanTest with AnalysisTes
   encodeDecodeTest(
     1 -> 10L,
     "tuple with 2 flat encoders")(
-    ExpressionEncoder.tuple(ExpressionEncoder[Int](), 
ExpressionEncoder[Long]()))
+    encoderFor(Encoders.tuple(Encoders.scalaInt, Encoders.scalaLong)))
 
   encodeDecodeTest(
     (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
     "tuple with 2 product encoders")(
-    ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData](), 
ExpressionEncoder[(Int, Long)]()))
+    encoderFor(Encoders.tuple(Encoders.product[PrimitiveData], 
Encoders.product[(Int, Long)])))
 
   encodeDecodeTest(
     (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
     "tuple with flat encoder and product encoder")(
-    ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData](), 
ExpressionEncoder[Int]()))
+    encoderFor(Encoders.tuple(Encoders.product[PrimitiveData], 
Encoders.scalaInt)))
 
   encodeDecodeTest(
     (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
     "tuple with product encoder and flat encoder")(
-    ExpressionEncoder.tuple(ExpressionEncoder[Int](), 
ExpressionEncoder[PrimitiveData]()))
+    encoderFor(Encoders.tuple(Encoders.scalaInt, 
Encoders.product[PrimitiveData])))
 
   encodeDecodeTest(
     (1, (10, 100L)),
     "nested tuple encoder") {
-    val intEnc = ExpressionEncoder[Int]()
-    val longEnc = ExpressionEncoder[Long]()
-    ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
+    val intEnc = Encoders.scalaInt
+    val longEnc = Encoders.scalaLong
+    encoderFor(Encoders.tuple(intEnc, Encoders.tuple(intEnc, longEnc)))
   }
 
   // test for value classes
@@ -468,9 +468,8 @@ class ExpressionEncoderSuite extends 
CodegenInterpretedPlanTest with AnalysisTes
 
     // test for tupled encoders
     {
-      val schema = ExpressionEncoder.tuple(
-        ExpressionEncoder[Int](),
-        ExpressionEncoder[(String, Int)]()).schema
+      val encoder = encoderFor(Encoders.tuple(Encoders.scalaInt, 
Encoders.product[(String, Int)]))
+      val schema = encoder.schema
       assert(schema(0).nullable === false)
       assert(schema(1).nullable)
       assert(schema(1).dataType.asInstanceOf[StructType](0).nullable)
@@ -513,11 +512,11 @@ class ExpressionEncoderSuite extends 
CodegenInterpretedPlanTest with AnalysisTes
   }
 
   test("throw exception for tuples with more than 22 elements") {
-    val encoders = (0 to 22).map(_ => 
Encoders.scalaInt.asInstanceOf[ExpressionEncoder[_]])
+    val encoders = (0 to 22).map(_ => Encoders.scalaInt)
 
     checkError(
       exception = intercept[SparkUnsupportedOperationException] {
-        ExpressionEncoder.tuple(encoders)
+        Encoders.tupleEncoder(encoders: _*)
       },
       condition = "_LEGACY_ERROR_TEMP_2150",
       parameters = Map.empty)
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index bb6d52308c19..33c9edb1cd21 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{Dataset, Encoders, 
ForeachWriter, Observation, Rela
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, 
PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, 
UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, 
UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, 
UnresolvedTranspose}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -2318,16 +2318,17 @@ class SparkConnectPlanner(
     if (fun.getArgumentsCount != 1) {
       throw InvalidPlanInput("reduce requires single child expression")
     }
-    val udf = fun.getArgumentsList.asScala.map(transformExpression) match {
-      case collection.Seq(f: ScalaUDF) =>
-        f
+    val udf = fun.getArgumentsList.asScala match {
+      case collection.Seq(e)
+          if e.hasCommonInlineUserDefinedFunction &&
+            e.getCommonInlineUserDefinedFunction.hasScalarScalaUdf =>
+        unpackUdf(e.getCommonInlineUserDefinedFunction)
       case other =>
         throw InvalidPlanInput(s"reduce should carry a scalar scala udf, but 
got $other")
     }
-    assert(udf.outputEncoder.isDefined)
-    val tEncoder = udf.outputEncoder.get // (T, T) => T
-    val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr
-    TypedAggUtils.withInputType(reduce, tEncoder, dataAttributes)
+    val encoder = udf.outputEncoder
+    val reduce = ReduceAggregator(udf.function)(encoder).toColumn.expr
+    TypedAggUtils.withInputType(reduce, encoderFor(encoder), dataAttributes)
   }
 
   private def transformExpressionWithTypedReduceExpression(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6e5dcc24e29d..c147b6a56e02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, 
InternalRow, Query
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.encoders._
+import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
ProductEncoder, StructEncoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
@@ -78,13 +79,14 @@ private[sql] object Dataset {
   val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
 
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): 
Dataset[T] = {
-    val dataset = new Dataset(sparkSession, logicalPlan, 
implicitly[Encoder[T]])
+    val encoder = implicitly[Encoder[T]]
+    val dataset = new Dataset(sparkSession, logicalPlan, encoder)
     // Eagerly bind the encoder so we verify that the encoder matches the 
underlying
     // schema. The user will get an error if this is not the case.
     // optimization: it is guaranteed that [[InternalRow]] can be converted to 
[[Row]] so
     // do not do this check in that case. this check can be expensive since it 
requires running
     // the whole [[Analyzer]] to resolve the deserializer
-    if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) {
+    if (dataset.encoder.clsTag.runtimeClass != classOf[Row]) {
       dataset.resolvedEnc
     }
     dataset
@@ -94,7 +96,7 @@ private[sql] object Dataset {
     sparkSession.withActive {
       val qe = sparkSession.sessionState.executePlan(logicalPlan)
       qe.assertAnalyzed()
-      new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+      new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
   }
 
   def ofRows(
@@ -105,7 +107,7 @@ private[sql] object Dataset {
       val qe = new QueryExecution(
         sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
       qe.assertAnalyzed()
-      new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+      new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
     }
 
   /** A variant of ofRows that allows passing in a tracker so we can track 
query parsing time. */
@@ -118,7 +120,7 @@ private[sql] object Dataset {
     val qe = new QueryExecution(
       sparkSession, logicalPlan, tracker, shuffleCleanupMode = 
shuffleCleanupMode)
     qe.assertAnalyzed()
-    new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+    new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
   }
 }
 
@@ -252,12 +254,17 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Currently [[ExpressionEncoder]] is the only implementation of 
[[Encoder]], here we turn the
-   * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it 
implicit so that we can use
-   * it when constructing new Dataset objects that have the same object type 
(that will be
-   * possibly resolved to a different schema).
+   * Expose the encoder as implicit so it can be used to construct new Dataset 
objects that have
+   * the same external type.
    */
-  private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
+  private implicit def encoderImpl: Encoder[T] = encoder
+
+  /**
+   * The actual [[ExpressionEncoder]] used by the dataset. This and its 
resolved counterpart should
+   * only be used for actual (de)serialization, the binding of Aggregator 
inputs, and in the rare
+   * cases where a plan needs to be constructed with an ExpressionEncoder.
+   */
+  private[sql] lazy val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
 
   // The resolved `ExpressionEncoder` which can be used to turn rows to 
objects of type T, after
   // collecting rows to the driver side.
@@ -265,7 +272,7 @@ class Dataset[T] private[sql](
     exprEnc.resolveAndBind(logicalPlan.output, 
sparkSession.sessionState.analyzer)
   }
 
-  private implicit def classTag: ClassTag[T] = exprEnc.clsTag
+  private implicit def classTag: ClassTag[T] = encoder.clsTag
 
   // sqlContext must be val because a stable identifier is expected when you 
import implicits
   @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
@@ -476,7 +483,7 @@ class Dataset[T] private[sql](
   /** @inheritdoc */
   // This is declared with parentheses to prevent the Scala compiler from 
treating
   // `ds.toDF("1")` as invoking this toDF and then apply on the returned 
DataFrame.
-  def toDF(): DataFrame = new Dataset[Row](queryExecution, 
ExpressionEncoder(schema))
+  def toDF(): DataFrame = new Dataset[Row](queryExecution, 
RowEncoder.encoderFor(schema))
 
   /** @inheritdoc */
   def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
@@ -671,17 +678,17 @@ class Dataset[T] private[sql](
         Some(condition.expr),
         JoinHint.NONE)).analyzed.asInstanceOf[Join]
 
-    implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder
-        .tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = 
true)
-        .asInstanceOf[Encoder[(T, U)]]
-
-    withTypedPlan(JoinWith.typedJoinWith(
+    val leftEncoder = agnosticEncoderFor(encoder)
+    val rightEncoder = agnosticEncoderFor(other.encoder)
+    val joinEncoder = ProductEncoder.tuple(Seq(leftEncoder, rightEncoder), 
elementsCanBeNull = true)
+      .asInstanceOf[Encoder[(T, U)]]
+    val joinWith = JoinWith.typedJoinWith(
       joined,
       sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity,
       sparkSession.sessionState.analyzer.resolver,
-      this.exprEnc.isSerializedAsStructForTopLevel,
-      other.exprEnc.isSerializedAsStructForTopLevel))
+      leftEncoder.isStruct,
+      rightEncoder.isStruct)
+    new Dataset(sparkSession, joinWith, joinEncoder)
   }
 
   // TODO(SPARK-22947): Fix the DataFrame API.
@@ -826,24 +833,29 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder)
+    val encoder = agnosticEncoderFor(c1.encoder)
     val tc1 = withInputType(c1.named, exprEnc, logicalPlan.output)
     val project = Project(tc1 :: Nil, logicalPlan)
 
-    if (!encoder.isSerializedAsStructForTopLevel) {
-      new Dataset[U1](sparkSession, project, encoder)
-    } else {
-      // Flattens inner fields of U1
-      new Dataset[Tuple1[U1]](sparkSession, project, 
ExpressionEncoder.tuple(encoder)).map(_._1)
+    val plan = encoder match {
+      case se: StructEncoder[U1] =>
+        // Flatten the result.
+        val attribute = GetColumnByOrdinal(0, se.dataType)
+        val projectList = se.fields.zipWithIndex.map {
+          case (field, index) =>
+            Alias(GetStructField(attribute, index, None), field.name)()
+        }
+        Project(projectList, project)
+      case _ => project
     }
+    new Dataset[U1](sparkSession, plan, encoder)
   }
 
   /** @inheritdoc */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val encoders = columns.map(c => encoderFor(c.encoder))
+    val encoders = columns.map(c => agnosticEncoderFor(c.encoder))
     val namedColumns = columns.map(c => withInputType(c.named, exprEnc, 
logicalPlan.output))
-    val execution = new QueryExecution(sparkSession, Project(namedColumns, 
logicalPlan))
-    new Dataset(execution, ExpressionEncoder.tuple(encoders))
+    new Dataset(sparkSession, Project(namedColumns, logicalPlan), 
ProductEncoder.tuple(encoders))
   }
 
   /** @inheritdoc */
@@ -912,8 +924,8 @@ class Dataset[T] private[sql](
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
 
     new KeyValueGroupedDataset(
-      encoderFor[K],
-      encoderFor[T],
+      implicitly[Encoder[K]],
+      encoder,
       executed,
       logicalPlan.output,
       withGroupingKey.newColumns)
@@ -1387,7 +1399,11 @@ class Dataset[T] private[sql](
       packageNames: Array[Byte],
       broadcastVars: Array[Broadcast[Object]],
       schema: StructType): DataFrame = {
-    val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+    val rowEncoder: ExpressionEncoder[Row] = if (isUnTyped) {
+      exprEnc.asInstanceOf[ExpressionEncoder[Row]]
+    } else {
+      ExpressionEncoder(schema)
+    }
     Dataset.ofRows(
       sparkSession,
       MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, 
logicalPlan))
@@ -2237,7 +2253,7 @@ class Dataset[T] private[sql](
 
   /** A convenient function to wrap a set based logical plan and produce a 
Dataset. */
   @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): 
Dataset[U] = {
-    if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
+    if (isUnTyped) {
       // Set operators widen types (change the schema), so we cannot reuse the 
row encoder.
       Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
     } else {
@@ -2245,6 +2261,8 @@ class Dataset[T] private[sql](
     }
   }
 
+  private def isUnTyped: Boolean = 
classTag.runtimeClass.isAssignableFrom(classOf[Row])
+
   /** Returns a optimized plan for CommandResult, convert to `LocalRelation`. 
*/
   private def commandResultOptimized: Dataset[T] = {
     logicalPlan match {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 1ebdd57f1962..fcad1b721eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, 
UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
ProductEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
@@ -43,9 +44,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
   extends api.KeyValueGroupedDataset[K, V, Dataset] {
   type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
 
-  // Similar to [[Dataset]], we turn the passed in encoder to 
`ExpressionEncoder` explicitly.
-  private implicit val kExprEnc: ExpressionEncoder[K] = encoderFor(kEncoder)
-  private implicit val vExprEnc: ExpressionEncoder[V] = encoderFor(vEncoder)
+  private implicit def kEncoderImpl: Encoder[K] = kEncoder
+  private implicit def vEncoderImpl: Encoder[V] = vEncoder
 
   private def logicalPlan = queryExecution.analyzed
   private def sparkSession = queryExecution.sparkSession
@@ -54,8 +54,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
   /** @inheritdoc */
   def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
     new KeyValueGroupedDataset(
-      encoderFor[L],
-      vExprEnc,
+      implicitly[Encoder[L]],
+      vEncoder,
       queryExecution,
       dataAttributes,
       groupingAttributes)
@@ -67,8 +67,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
     val executed = sparkSession.sessionState.executePlan(projected)
 
     new KeyValueGroupedDataset(
-      encoderFor[K],
-      encoderFor[W],
+      kEncoder,
+      implicitly[Encoder[W]],
       executed,
       withNewData.newColumns,
       groupingAttributes)
@@ -297,20 +297,21 @@ class KeyValueGroupedDataset[K, V] private[sql](
 
   /** @inheritdoc */
   def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
-    val vEncoder = encoderFor[V]
     val aggregator: TypedColumn[V, V] = new 
ReduceAggregator[V](f)(vEncoder).toColumn
     agg(aggregator)
   }
 
   /** @inheritdoc */
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val encoders = columns.map(c => encoderFor(c.encoder))
-    val namedColumns = columns.map(c => withInputType(c.named, vExprEnc, 
dataAttributes))
-    val keyColumn = aggKeyColumn(kExprEnc, groupingAttributes)
+    val keyAgEncoder = agnosticEncoderFor(kEncoder)
+    val valueExprEncoder = encoderFor(vEncoder)
+    val encoders = columns.map(c => agnosticEncoderFor(c.encoder))
+    val namedColumns = columns.map { c =>
+      withInputType(c.named, valueExprEncoder, dataAttributes)
+    }
+    val keyColumn = aggKeyColumn(keyAgEncoder, groupingAttributes)
     val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, 
logicalPlan)
-    val execution = new QueryExecution(sparkSession, aggregate)
-
-    new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders))
+    new Dataset(sparkSession, aggregate, ProductEncoder.tuple(keyAgEncoder +: 
encoders))
   }
 
   /** @inheritdoc */
@@ -319,7 +320,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
       thisSortExprs: Column*)(
       otherSortExprs: Column*)(
       f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
-    implicit val uEncoder = other.vExprEnc
+    implicit val uEncoder = other.vEncoderImpl
     Dataset[R](
       sparkSession,
       CoGroup(
@@ -336,10 +337,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
 
   override def toString: String = {
     val builder = new StringBuilder
-    val kFields = kExprEnc.schema.map { f =>
+    val kFields = kEncoder.schema.map { f =>
       s"${f.name}: ${f.dataType.simpleString(2)}"
     }
-    val vFields = vExprEnc.schema.map { f =>
+    val vFields = vEncoder.schema.map { f =>
       s"${f.name}: ${f.dataType.simpleString(2)}"
     }
     builder.append("KeyValueGroupedDataset: [key: [")
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 4e4454018e81..da4609135fd6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -22,7 +22,6 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -119,15 +118,12 @@ class RelationalGroupedDataset protected[sql](
 
   /** @inheritdoc */
   def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
-    val keyEncoder = encoderFor[K]
-    val valueEncoder = encoderFor[T]
-
     val (qe, groupingAttributes) =
       handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs)
 
     new KeyValueGroupedDataset(
-      keyEncoder,
-      valueEncoder,
+      implicitly[Encoder[K]],
+      implicitly[Encoder[T]],
       qe,
       df.logicalPlan.output,
       groupingAttributes)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 3832d7304407..09d9915022a6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -504,7 +504,7 @@ case class ScalaAggregator[IN, BUF, OUT](
   private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
   private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
   private[this] lazy val bufferDeserializer = 
bufferEncoder.createDeserializer()
-  private[this] lazy val outputEncoder = 
agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
+  private[this] lazy val outputEncoder = encoderFor(agg.outputEncoder)
   private[this] lazy val outputSerializer = outputEncoder.createSerializer()
 
   def dataType: DataType = outputEncoder.objSerializer.dataType
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index 420c3e3be16d..273ffa6aefb7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -32,8 +32,9 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{HOST, PORT}
 import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.connector.read.InputPartition
 import 
org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, 
ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
@@ -57,8 +58,7 @@ class TextSocketContinuousStream(
 
   implicit val defaultFormats: DefaultFormats = DefaultFormats
 
-  private val encoder = ExpressionEncoder.tuple(ExpressionEncoder[String](),
-    ExpressionEncoder[Timestamp]())
+  private val encoder = encoderFor(Encoders.tuple(Encoders.STRING, 
Encoders.TIMESTAMP))
 
   @GuardedBy("this")
   private var socket: Socket = _
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index fd3df372a2d5..192b5bf65c4c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.expressions
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder,
 ProductEncoder}
 
 /**
  * An aggregator that uses a single associative and commutative reduce 
function. This reduce
@@ -46,10 +47,10 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, 
T) => T)
 
   override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T])
 
-  override def bufferEncoder: Encoder[(Boolean, T)] =
-    ExpressionEncoder.tuple(
-      ExpressionEncoder[Boolean](),
-      encoder.asInstanceOf[ExpressionEncoder[T]])
+  override def bufferEncoder: Encoder[(Boolean, T)] = {
+    ProductEncoder.tuple(Seq(PrimitiveBooleanEncoder, 
encoder.asInstanceOf[AgnosticEncoder[T]]))
+      .asInstanceOf[Encoder[(Boolean, T)]]
+  }
 
   override def outputEncoder: Encoder[T] = encoder
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
index b6340a35e770..23ceb8135fa8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.internal
 
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
@@ -25,10 +27,10 @@ import 
org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 private[sql] object TypedAggUtils {
 
   def aggKeyColumn[A](
-      encoder: ExpressionEncoder[A],
+      encoder: Encoder[A],
       groupingAttributes: Seq[Attribute]): NamedExpression = {
-    if (!encoder.isSerializedAsStructForTopLevel) {
-      assert(groupingAttributes.length == 1)
+    val agnosticEncoder = agnosticEncoderFor(encoder)
+    if (!agnosticEncoder.isStruct) {
       if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
         groupingAttributes.head
       } else {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
index ea6fd8ab312c..2456999b4382 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.GroupStateImpl._
 import org.apache.spark.sql.streaming.StreamTest
@@ -201,7 +201,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends 
StreamTest {
 
   private def newStateManager[T: Encoder](version: Int, withTimestamp: 
Boolean): StateManager = {
     FlatMapGroupsWithStateExecHelper.createStateManager(
-      implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]],
+      encoderFor[T].asInstanceOf[ExpressionEncoder[Any]],
       withTimestamp,
       version)
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index add12f7e1535..e9300464af8d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -22,7 +22,7 @@ import java.util.UUID
 
 import org.apache.spark.{SparkIllegalArgumentException, 
SparkUnsupportedOperationException}
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
ListStateImplWithTTL, StatefulProcessorHandleImpl}
 import org.apache.spark.sql.streaming.{ListState, TimeMode, TTLConfig, 
ValueState}
 
@@ -38,7 +38,7 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val listState: ListState[Long] = handle.getListState[Long]("listState", 
Encoders.scalaLong)
 
@@ -71,7 +71,7 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: ListState[Long] = handle.getListState[Long]("testState", 
Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -99,7 +99,7 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState1: ListState[Long] = 
handle.getListState[Long]("testState1", Encoders.scalaLong)
       val testState2: ListState[Long] = 
handle.getListState[Long]("testState2", Encoders.scalaLong)
@@ -137,7 +137,7 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val listState1: ListState[Long] = 
handle.getListState[Long]("listState1", Encoders.scalaLong)
       val listState2: ListState[Long] = 
handle.getListState[Long]("listState2", Encoders.scalaLong)
@@ -167,7 +167,7 @@ class ListStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
 
       val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
@@ -187,7 +187,7 @@ class ListStateSuite extends StateVariableSuiteBase {
 
       // increment batchProcessingTime, or watermark and ensure expired value 
is not returned
       val nextBatchHandle = new StatefulProcessorHandleImpl(store, 
UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
 
       val nextBatchTestState: ListStateImplWithTTL[String] =
@@ -223,7 +223,7 @@ class ListStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val batchTimestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
 
       Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
@@ -250,7 +250,7 @@ class ListStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.bean(classOf[POJOTestClass]).asInstanceOf[ExpressionEncoder[Any]],
+        encoderFor(Encoders.bean(classOf[POJOTestClass])).asInstanceOf[ExpressionEncoder[Any]],
         TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
 
       val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index 9c322b201da8..b067d589de90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -22,7 +22,6 @@ import java.util.UUID
 
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
MapStateImplWithTTL, StatefulProcessorHandleImpl}
 import org.apache.spark.sql.streaming.{ListState, MapState, TimeMode, 
TTLConfig, ValueState}
 import org.apache.spark.sql.types.{BinaryType, StructType}
@@ -41,7 +40,7 @@ class MapStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: MapState[String, Double] =
         handle.getMapState[String, Double]("testState", Encoders.STRING, 
Encoders.scalaDouble)
@@ -75,7 +74,7 @@ class MapStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState1: MapState[Long, Double] =
         handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, 
Encoders.scalaDouble)
@@ -114,7 +113,7 @@ class MapStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val mapTestState1: MapState[String, Int] =
         handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, 
Encoders.scalaInt)
@@ -175,7 +174,7 @@ class MapStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
+        stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(timestampMs))
 
       val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
@@ -196,7 +195,7 @@ class MapStateSuite extends StateVariableSuiteBase {
 
       // increment batchProcessingTime, or watermark and ensure expired value 
is not returned
       val nextBatchHandle = new StatefulProcessorHandleImpl(store, 
UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
 
       val nextBatchTestState: MapStateImplWithTTL[String, String] =
@@ -233,7 +232,7 @@ class MapStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val batchTimestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
 
       Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
@@ -261,7 +260,7 @@ class MapStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
+        stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(timestampMs))
 
       val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index e2940497e911..48a6fd836a46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -22,7 +22,6 @@ import java.util.UUID
 
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
 import org.apache.spark.sql.streaming.{TimeMode, TTLConfig}
 
@@ -33,9 +32,6 @@ import org.apache.spark.sql.streaming.{TimeMode, TTLConfig}
  */
 class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
-  private def keyExprEncoder: ExpressionEncoder[Any] =
-    Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]
-
   private def getTimeMode(timeMode: String): TimeMode = {
     timeMode match {
       case "None" => TimeMode.None()
@@ -50,7 +46,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+          UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
         assert(handle.getHandleState === StatefulProcessorHandleState.CREATED)
         handle.getValueState[Long]("testState", Encoders.scalaLong)
       }
@@ -91,7 +87,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+          UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
 
         Seq(StatefulProcessorHandleState.INITIALIZED,
           StatefulProcessorHandleState.DATA_PROCESSED,
@@ -109,7 +105,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), keyExprEncoder, TimeMode.None())
+        UUID.randomUUID(), stringEncoder, TimeMode.None())
       val ex = intercept[SparkUnsupportedOperationException] {
         handle.registerTimer(10000L)
       }
@@ -145,7 +141,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+          UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
         handle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
         assert(handle.getHandleState === 
StatefulProcessorHandleState.INITIALIZED)
 
@@ -166,7 +162,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+          UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
         handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
         assert(handle.getHandleState === 
StatefulProcessorHandleState.DATA_PROCESSED)
 
@@ -206,7 +202,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode))
+          UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
 
         Seq(StatefulProcessorHandleState.CREATED,
           StatefulProcessorHandleState.TIMER_PROCESSED,
@@ -223,7 +219,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+        UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(10))
 
       val valueStateWithTTL = handle.getValueState("testState",
@@ -241,7 +237,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+        UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(10))
 
       val listStateWithTTL = handle.getListState("testState",
@@ -259,7 +255,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+        UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(10))
 
       val mapStateWithTTL = handle.getMapState("testState",
@@ -277,7 +273,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), keyExprEncoder, TimeMode.None())
+        UUID.randomUUID(), stringEncoder, TimeMode.None())
 
       handle.getValueState("testValueState", Encoders.STRING)
       handle.getListState("testListState", Encoders.STRING)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index df6a3fd7b23e..24a120be9d9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
TimerStateImpl}
 import org.apache.spark.sql.streaming.TimeMode
 
@@ -45,7 +45,7 @@ class TimerSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       val timerState = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       timerState.registerTimer(1L * 1000)
       assert(timerState.listTimers().toSet === Set(1000L))
       assert(timerState.getExpiredTimers(Long.MaxValue).toSeq === 
Seq(("test_key", 1000L)))
@@ -64,9 +64,9 @@ class TimerSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       val timerState1 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       val timerState2 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       timerState1.registerTimer(1L * 1000)
       timerState2.registerTimer(15L * 1000)
       assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
@@ -89,7 +89,7 @@ class TimerSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
       val timerState1 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       timerState1.registerTimer(1L * 1000)
       timerState1.registerTimer(2L * 1000)
       assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
@@ -97,7 +97,7 @@ class TimerSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
       val timerState2 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       timerState2.registerTimer(15L * 1000)
       ImplicitGroupingKeyTracker.removeImplicitKey()
 
@@ -122,7 +122,7 @@ class TimerSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       val timerState = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       val timerTimerstamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 
3L, 35L, 6L, 9L, 5L)
       // register/put unordered timestamp into rocksDB
       timerTimerstamps.foreach(timerState.registerTimer)
@@ -141,19 +141,19 @@ class TimerSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
       val timerState1 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       val timerTimestamps1 = Seq(64L, 32L, 1024L, 4096L, 0L, 1L)
       timerTimestamps1.foreach(timerState1.registerTimer)
 
       val timerState2 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       val timerTimestamps2 = Seq(931L, 8000L, 452300L, 4200L)
       timerTimestamps2.foreach(timerState2.registerTimer)
       ImplicitGroupingKeyTracker.removeImplicitKey()
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key3")
       val timerState3 = new TimerStateImpl(store, timeMode,
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        stringEncoder)
       val timerTimerStamps3 = Seq(1L, 2L, 8L, 3L)
       timerTimerStamps3.foreach(timerState3.registerTimer)
       ImplicitGroupingKeyTracker.removeImplicitKey()
@@ -171,7 +171,7 @@ class TimerSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       ImplicitGroupingKeyTracker.setImplicitKey(TestClass(1L, "k1"))
       val timerState = new TimerStateImpl(store, timeMode,
-        Encoders.product[TestClass].asInstanceOf[ExpressionEncoder[Any]])
+        encoderFor(Encoders.product[TestClass]).asInstanceOf[ExpressionEncoder[Any]])
 
       timerState.registerTimer(1L * 1000)
       assert(timerState.listTimers().toSet === Set(1000L))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 41912a4dda23..13d758eb1b88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, ValueStateImplWithTTL}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{TimeMode, TTLConfig, ValueState}
@@ -49,7 +49,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val stateName = "testState"
       val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong)
@@ -93,7 +93,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -119,7 +119,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState1: ValueState[Long] = handle.getValueState[Long](
         "testState1", Encoders.scalaLong)
@@ -164,7 +164,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        UUID.randomUUID(), stringEncoder, TimeMode.None())
 
       val cfName = "$testState"
       val ex = intercept[SparkUnsupportedOperationException] {
@@ -204,7 +204,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: ValueState[Double] = 
handle.getValueState[Double]("testState",
         Encoders.scalaDouble)
@@ -230,7 +230,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: ValueState[Long] = handle.getValueState[Long]("testState",
         Encoders.scalaLong)
@@ -256,7 +256,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: ValueState[TestClass] = 
handle.getValueState[TestClass]("testState",
         Encoders.product[TestClass])
@@ -282,7 +282,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None())
+        stringEncoder, TimeMode.None())
 
       val testState: ValueState[POJOTestClass] = 
handle.getValueState[POJOTestClass]("testState",
         Encoders.bean(classOf[POJOTestClass]))
@@ -310,7 +310,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
+        stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(timestampMs))
 
       val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
@@ -330,7 +330,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
 
       // increment batchProcessingTime, or watermark and ensure expired value 
is not returned
       val nextBatchHandle = new StatefulProcessorHandleImpl(store, 
UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
 
       val nextBatchTestState: ValueStateImplWithTTL[String] =
@@ -366,7 +366,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val batchTimestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
 
       Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
@@ -393,8 +393,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.product[TestClass].asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(),
-        batchTimestampMs = Some(timestampMs))
+        encoderFor(Encoders.product[TestClass]).asInstanceOf[ExpressionEncoder[Any]],
+        TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
 
       val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
       val testState: ValueStateImplWithTTL[POJOTestClass] =
@@ -437,6 +437,8 @@ abstract class StateVariableSuiteBase extends SharedSparkSession
 
   import StateStoreTestsHelper._
 
+  protected val stringEncoder = encoderFor(Encoders.STRING).asInstanceOf[ExpressionEncoder[Any]]
+
   // dummy schema for initializing rocksdb provider
   protected def schemaForKeyRow: StructType = new StructType().add("key", 
BinaryType)
   protected def schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to