This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new e7fef70 [SPARK-31450][SQL] Make ExpressionEncoder thread-safe
e7fef70 is described below
commit e7fef70fbbea08a38316abdaa9445123bb8c39e2
Author: herman <[email protected]>
AuthorDate: Thu Apr 16 18:47:46 2020 -0700
[SPARK-31450][SQL] Make ExpressionEncoder thread-safe
### What changes were proposed in this pull request?
This PR moves the `ExpressionEncoder.toRow` and `ExpressionEncoder.fromRow`
functions into their own function objects (`ExpressionEncoder.Serializer` &
`ExpressionEncoder.Deserializer`). This effectively makes the
`ExpressionEncoder` stateless, thread-safe and (more) reusable. The function
objects are not thread-safe; however, they are documented as such and should be
used in a more limited scope (making it easier to reason about thread safety).
### Why are the changes needed?
ExpressionEncoders are not thread-safe. We had various (nasty) bugs because
of this.
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #28223 from hvanhovell/SPARK-31450.
Authored-by: herman <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit fab4ca5156d5e1cc0e976c7c27b28a12fa61eb6d)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/ml/source/image/ImageFileFormat.scala | 4 +-
.../spark/ml/source/libsvm/LibSVMRelation.scala | 4 +-
.../mllib/linalg/UDTSerializationBenchmark.scala | 10 +-
.../main/scala/org/apache/spark/sql/Encoder.scala | 9 +-
.../sql/catalyst/encoders/ExpressionEncoder.scala | 102 ++++++++++++++-------
.../spark/sql/catalyst/expressions/ScalaUDF.scala | 4 +-
.../scala/org/apache/spark/sql/HashBenchmark.scala | 4 +-
.../spark/sql/UnsafeProjectionBenchmark.scala | 4 +-
.../catalyst/encoders/EncoderResolutionSuite.scala | 28 ++++--
.../catalyst/encoders/ExpressionEncoderSuite.scala | 22 ++---
.../sql/catalyst/encoders/RowEncoderSuite.scala | 49 ++++++----
.../expressions/HashExpressionsSuite.scala | 6 +-
.../expressions/ObjectExpressionsSuite.scala | 4 +-
.../codegen/GenerateUnsafeRowJoinerSuite.scala | 4 +-
.../catalyst/util/ArrayDataIndexedSeqSuite.scala | 4 +-
.../spark/sql/catalyst/util/UnsafeArraySuite.scala | 74 +++++----------
.../main/scala/org/apache/spark/sql/Dataset.scala | 13 ++-
.../scala/org/apache/spark/sql/SparkSession.scala | 9 +-
.../spark/sql/execution/SparkStrategies.scala | 3 +-
.../spark/sql/execution/aggregate/udaf.scala | 15 +--
.../execution/datasources/DataSourceStrategy.scala | 4 +-
.../sql/execution/datasources/jdbc/JdbcUtils.scala | 4 +-
.../datasources/v2/DescribeNamespaceExec.scala | 6 +-
.../datasources/v2/DescribeTableExec.scala | 6 +-
.../datasources/v2/ShowCurrentNamespaceExec.scala | 11 ++-
.../datasources/v2/ShowNamespacesExec.scala | 6 +-
.../datasources/v2/ShowTablePropertiesExec.scala | 6 +-
.../execution/datasources/v2/ShowTablesExec.scala | 12 +--
.../continuous/ContinuousTextSocketSource.scala | 3 +-
.../spark/sql/execution/streaming/memory.scala | 8 +-
.../streaming/sources/ContinuousMemoryStream.scala | 2 +-
.../streaming/sources/ForeachBatchSink.scala | 3 +-
.../streaming/sources/ForeachWriterTable.scala | 2 +-
.../sql/execution/streaming/sources/memory.scala | 4 +-
.../apache/spark/sql/internal/CatalogImpl.scala | 3 +-
.../spark/sql/execution/GroupedIteratorSuite.scala | 15 ++-
.../benchmark/UnsafeArrayDataBenchmark.scala | 29 +++---
.../binaryfile/BinaryFileFormatSuite.scala | 2 +-
.../apache/spark/sql/streaming/StreamTest.scala | 22 +++--
39 files changed, 282 insertions(+), 238 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
index c332144..4944e0c 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
@@ -91,8 +91,8 @@ private[image] class ImageFileFormat extends FileFormat with
DataSourceRegister
if (requiredSchema.isEmpty) {
filteredResult.map(_ => emptyUnsafeRow)
} else {
- val converter = RowEncoder(requiredSchema)
- filteredResult.map(row => converter.toRow(row))
+ val toRow = RowEncoder(requiredSchema).createSerializer()
+ filteredResult.map(row => toRow(row))
}
}
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 6ead4df..da8f3a2 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -166,7 +166,7 @@ private[libsvm] class LibSVMFileFormat
LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
}
- val converter = RowEncoder(dataSchema)
+ val toRow = RowEncoder(dataSchema).createSerializer()
val fullOutput = dataSchema.map { f =>
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
}
@@ -178,7 +178,7 @@ private[libsvm] class LibSVMFileFormat
points.map { pt =>
val features = if (isSparse) pt.features.toSparse else
pt.features.toDense
- requiredColumns(converter.toRow(Row(pt.label, features)))
+ requiredColumns(toRow(Row(pt.label, features)))
}
}
}
diff --git
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
index 5f19e46..3caa8f6 100644
---
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
+++
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -38,12 +38,14 @@ object UDTSerializationBenchmark extends BenchmarkBase {
val iters = 1e2.toInt
val numRows = 1e3.toInt
- val encoder = ExpressionEncoder[Vector].resolveAndBind()
+ val encoder = ExpressionEncoder[Vector]().resolveAndBind()
+ val toRow = encoder.createSerializer()
+ val fromRow = encoder.createDeserializer()
val vectors = (1 to numRows).map { i =>
Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
}.toArray
- val rows = vectors.map(encoder.toRow)
+ val rows = vectors.map(toRow)
val benchmark = new Benchmark("VectorUDT de/serialization", numRows,
iters, output = output)
@@ -51,7 +53,7 @@ object UDTSerializationBenchmark extends BenchmarkBase {
var sum = 0
var i = 0
while (i < numRows) {
- sum += encoder.toRow(vectors(i)).numFields
+ sum += toRow(vectors(i)).numFields
i += 1
}
}
@@ -60,7 +62,7 @@ object UDTSerializationBenchmark extends BenchmarkBase {
var sum = 0
var i = 0
while (i < numRows) {
- sum += encoder.fromRow(rows(i)).numActives
+ sum += fromRow(rows(i)).numActives
i += 1
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index c43a86a..ea760d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -58,8 +58,7 @@ import org.apache.spark.sql.types._
* }}}
*
* == Implementation ==
- * - Encoders are not required to be thread-safe and thus they do not need to
use locks to guard
- * against concurrent access if they reuse internal buffers to improve
performance.
+ * - Encoders should be thread-safe.
*
* @since 1.6.0
*/
@@ -76,10 +75,4 @@ trait Encoder[T] extends Serializable {
* A ClassTag that can be used to construct an Array to contain a collection
of `T`.
*/
def clsTag: ClassTag[T]
-
- /**
- * Create a copied [[Encoder]]. The implementation may just copy internal
reusable fields to speed
- * up the [[Encoder]] creation.
- */
- def makeCopy: Encoder[T]
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index bd49967..213fbcc 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.encoders
+import java.io.ObjectInputStream
+
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference,
ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal,
SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer,
Serializer}
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull,
InitializeJavaBean, Invoke, NewInstance}
@@ -162,6 +165,56 @@ object ExpressionEncoder {
e4: ExpressionEncoder[T4],
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3,
T4, T5)]]
+
+ private val anyObjectType = ObjectType(classOf[Any])
+
+ /**
+ * Function that deserializes an [[InternalRow]] into an object of type `T`.
This class is not
+ * thread-safe.
+ */
+ class Deserializer[T](private val expressions: Seq[Expression])
+ extends (InternalRow => T) with Serializable {
+ @transient
+ private[this] var constructProjection: Projection = _
+
+ override def apply(row: InternalRow): T = try {
+ if (constructProjection == null) {
+ constructProjection = SafeProjection.create(expressions)
+ }
+ constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error while decoding: $e\n" +
+
s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}",
e)
+ }
+ }
+
+ /**
+ * Function that serializes an object of type `T` to an [[InternalRow]].
This class is not
+ * thread-safe. Note that multiple calls to `apply(..)` return the same
actual [[InternalRow]]
+ * object. Thus, the caller should copy the result before making another
call if required.
+ */
+ class Serializer[T](private val expressions: Seq[Expression])
+ extends (T => InternalRow) with Serializable {
+ @transient
+ private[this] var inputRow: GenericInternalRow = _
+
+ @transient
+ private[this] var extractProjection: UnsafeProjection = _
+
+ override def apply(t: T): InternalRow = try {
+ if (extractProjection == null) {
+ inputRow = new GenericInternalRow(1)
+ extractProjection = GenerateUnsafeProjection.generate(expressions)
+ }
+ inputRow(0) = t
+ extractProjection(inputRow)
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error while encoding: $e\n" +
+
s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}",
e)
+ }
+ }
}
/**
@@ -301,25 +354,22 @@ case class ExpressionEncoder[T](
}
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate({
+ private lazy val optimizedDeserializer: Seq[Expression] = {
// When using `ExpressionEncoder` directly, we will skip the normal query
processing steps
// (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID
rule, as it's
// important to codegen performance.
- val optimizedPlan =
ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
+ val optimizedPlan =
ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
- })
-
- @transient
- private lazy val inputRow = new GenericInternalRow(1)
+ }
@transient
- private lazy val constructProjection = SafeProjection.create({
+ private lazy val optimizedSerializer = {
// When using `ExpressionEncoder` directly, we will skip the normal query
processing steps
// (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID
rule, as it's
// important to codegen performance.
- val optimizedPlan =
ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
+ val optimizedPlan =
ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
- })
+ }
/**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent
the serialized form
@@ -331,31 +381,21 @@ case class ExpressionEncoder[T](
}
/**
- * Returns an encoded version of `t` as a Spark SQL row. Note that multiple
calls to
- * toRow are allowed to return the same actual [[InternalRow]] object.
Thus, the caller should
- * copy the result before making another call if required.
+ * Create a serializer that can convert an object of type `T` to a Spark SQL
Row.
+ *
+ * Note that the returned [[Serializer]] is not thread safe. Multiple calls
to
+ * `serializer.apply(..)` are allowed to return the same actual
[[InternalRow]] object. Thus,
+ * the caller should copy the result before making another call if required.
*/
- def toRow(t: T): InternalRow = try {
- inputRow(0) = t
- extractProjection(inputRow)
- } catch {
- case e: Exception =>
- throw new RuntimeException(s"Error while encoding: $e\n" +
-
s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}",
e)
- }
+ def createSerializer(): Serializer[T] = new
Serializer[T](optimizedSerializer)
/**
- * Returns an object of type `T`, extracting the required values from the
provided row. Note that
- * you must `resolveAndBind` an encoder to a specific schema before you can
call this
- * function.
+ * Create a deserializer that can convert a Spark SQL Row into an object of
type `T`.
+ *
+ * Note that you must `resolveAndBind` an encoder to a specific schema
before you can create a
+ * deserializer.
*/
- def fromRow(row: InternalRow): T = try {
- constructProjection(row).get(0,
ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
- } catch {
- case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n" +
- s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
- }
+ def createDeserializer(): Deserializer[T] = new
Deserializer[T](optimizedDeserializer)
/**
* The process of resolution to a given schema throws away information about
where a given field
@@ -382,8 +422,6 @@ case class ExpressionEncoder[T](
.map { case(f, a) => s"${f.name}$a:
${f.dataType.simpleString}"}.mkString(", ")
override def toString: String = s"class[$schemaString]"
-
- override def makeCopy: ExpressionEncoder[T] = copy()
}
// A dummy logical plan that can hold expressions and go through optimizer
rules.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 1ac7ca6..e80f03e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -110,8 +110,8 @@ case class ScalaUDF(
} else {
val encoder = inputEncoders(i)
if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
- val enc = encoder.get.resolveAndBind()
- row: Any => enc.fromRow(row.asInstanceOf[InternalRow])
+ val fromRow = encoder.get.resolveAndBind().createDeserializer()
+ row: Any => fromRow(row.asInstanceOf[InternalRow])
} else {
CatalystTypeConverters.createToScalaConverter(dataType)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
index 3b4b80d..3f0121b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
@@ -41,13 +41,13 @@ object HashBenchmark extends BenchmarkBase {
def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit =
{
runBenchmark(name) {
val generator = RandomDataGenerator.forType(schema, nullable = false).get
- val encoder = RowEncoder(schema)
+ val toRow = RowEncoder(schema).createSerializer()
val attrs = schema.toAttributes
val safeProjection = GenerateSafeProjection.generate(attrs, attrs)
val rows = (1 to numRows).map(_ =>
// The output of encoder is UnsafeRow, use safeProjection to turn in
into safe format.
- safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
+ safeProjection(toRow(generator().asInstanceOf[Row])).copy()
).toArray
val benchmark = new Benchmark("Hash For " + name, iters *
numRows.toLong, output = output)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
index 42a4cfc..950e313 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
@@ -37,8 +37,8 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {
def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
val generator = RandomDataGenerator.forType(schema, nullable = false).get
- val encoder = RowEncoder(schema)
- (1 to numRows).map(_ =>
encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
+ val toRow = RowEncoder(schema).createSerializer()
+ (1 to numRows).map(_ =>
toRow(generator().asInstanceOf[Row]).copy()).toArray
}
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 53cb8bc..48f4ef5 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
@@ -42,22 +43,29 @@ case class NestedArrayClass(nestedArr: Array[ArrayClass])
class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")
+ def testFromRow[T](
+ encoder: ExpressionEncoder[T],
+ attributes: Seq[Attribute],
+ row: InternalRow): Unit = {
+ encoder.resolveAndBind(attributes).createDeserializer().apply(row)
+ }
+
test("real type doesn't match encoder schema but they are compatible:
product") {
val encoder = ExpressionEncoder[StringLongClass]
// int type can be up cast to long type
val attrs1 = Seq('a.string, 'b.int)
- encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
+ testFromRow(encoder, attrs1, InternalRow(str, 1))
// int type can be up cast to string type
val attrs2 = Seq('a.int, 'b.long)
- encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
+ testFromRow(encoder, attrs2, InternalRow(1, 2L))
}
test("real type doesn't match encoder schema but they are compatible: nested
product") {
val encoder = ExpressionEncoder[ComplexClass]
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+ testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L)))
}
test("real type doesn't match encoder schema but they are compatible: tupled
encoder") {
@@ -65,14 +73,14 @@ class EncoderResolutionSuite extends PlanTest {
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
- encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str,
1.toByte), 2))
+ testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
}
test("real type doesn't match encoder schema but they are compatible:
primitive array") {
val encoder = ExpressionEncoder[PrimitiveArrayClass]
val attrs = Seq('arr.array(IntegerType))
val array = new GenericArrayData(Array(1, 2, 3))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+ testFromRow(encoder, attrs, InternalRow(array))
}
test("the real type is not compatible with encoder schema: primitive array")
{
@@ -93,7 +101,7 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("a", "int").add("b",
"int").add("c", "int")))
val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+ testFromRow(encoder, attrs, InternalRow(array))
}
test("real type doesn't match encoder schema but they are compatible: nested
array") {
@@ -103,7 +111,7 @@ class EncoderResolutionSuite extends PlanTest {
val attrs = Seq('nestedArr.array(et))
val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
+ testFromRow(encoder, attrs, InternalRow(outerArr))
}
test("the real type is not compatible with encoder schema: non-array field")
{
@@ -142,14 +150,14 @@ class EncoderResolutionSuite extends PlanTest {
val attrs = 'a.array(IntegerType) :: Nil
// It should pass analysis
- val bound = encoder.resolveAndBind(attrs)
+ val fromRow = encoder.resolveAndBind(attrs).createDeserializer()
// If no null values appear, it should work fine
- bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+ fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
// If there is null value, it should throw runtime exception
val e = intercept[RuntimeException] {
- bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+ fromRow(InternalRow(new GenericArrayData(Array(1, null))))
}
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 1036dc7..6a094d4 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -369,14 +369,14 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
}
test("null check for map key: String") {
- val encoder = ExpressionEncoder[Map[String, Int]]()
- val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null,
2))))
+ val toRow = ExpressionEncoder[Map[String, Int]]().createSerializer()
+ val e = intercept[RuntimeException](toRow(Map(("a", 1), (null, 2))))
assert(e.getMessage.contains("Cannot use null as map key"))
}
test("null check for map key: Integer") {
- val encoder = ExpressionEncoder[Map[Integer, String]]()
- val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null,
"b"))))
+ val toRow = ExpressionEncoder[Map[Integer, String]]().createSerializer()
+ val e = intercept[RuntimeException](toRow(Map((1, "a"), (null, "b"))))
assert(e.getMessage.contains("Cannot use null as map key"))
}
@@ -436,10 +436,6 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int")
testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int")
- encodeDecodeTest("foo" -> 1L, "makeCopy") {
- Encoders.product[(String,
Long)].makeCopy.asInstanceOf[ExpressionEncoder[(String, Long)]]
- }
-
private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName:
String): Unit = {
Seq(true, false).foreach { ansiEnabled =>
testAndVerifyNotLeakingReflectionObjects(
@@ -450,12 +446,14 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
// Need to construct Encoder here rather than implicitly resolving it
// so that SQLConf changes are respected.
val encoder = ExpressionEncoder[T]()
+ val toRow = encoder.createSerializer()
if (!ansiEnabled) {
- val convertedBack =
encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric))
+ val fromRow = encoder.resolveAndBind().createDeserializer()
+ val convertedBack = fromRow(toRow(bigNumeric))
assert(convertedBack === null)
} else {
val e = intercept[RuntimeException] {
- encoder.toRow(bigNumeric)
+ toRow(bigNumeric)
}
assert(e.getMessage.contains("Error while encoding"))
assert(e.getCause.getClass === classOf[ArithmeticException])
@@ -474,10 +472,10 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
// Make sure encoder is serializable.
ClosureCleaner.clean((s: String) => encoder.getClass.getName)
- val row = encoder.toRow(input)
+ val row = encoder.createSerializer().apply(input)
val schema = encoder.schema.toAttributes
val boundEncoder = encoder.resolveAndBind()
- val convertedBack = try boundEncoder.fromRow(row) catch {
+ val convertedBack = try boundEncoder.createDeserializer().apply(row)
catch {
case e: Exception =>
fail(
s"""Exception thrown while decoding
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a1cab8..c1158e0 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.util.Random
import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils,
GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
@@ -81,6 +82,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
+ private def toRow(encoder: ExpressionEncoder[Row], row: Row): InternalRow = {
+ encoder.createSerializer().apply(row)
+ }
+
+ private def fromRow(encoder: ExpressionEncoder[Row], row: InternalRow): Row
= {
+ encoder.createDeserializer().apply(row)
+ }
+
+ private def roundTrip(encoder: ExpressionEncoder[Row], row: Row): Row = {
+ fromRow(encoder, toRow(encoder, row))
+ }
+
encodeDecodeTest(
new StructType()
.add("null", NullType)
@@ -144,8 +157,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val catalystDecimal = Decimal("1234.5678")
val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal,
catalystDecimal)
- val row = encoder.toRow(input)
- val convertedBack = encoder.fromRow(row)
+ val convertedBack = roundTrip(encoder, input)
// Decimal will be converted back to Java BigDecimal when decoding.
assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0)
assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0)
@@ -157,7 +169,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val encoder = RowEncoder(schema).resolveAndBind()
val decimal = Decimal("67123.45")
val input = Row(decimal)
- val row = encoder.toRow(input)
+ val row = toRow(encoder, input)
assert(row.toSeq(schema).head == decimal)
}
@@ -172,7 +184,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
val encoder = RowEncoder(schema).resolveAndBind()
intercept[Exception] {
- encoder.toRow(row)
+ toRow(encoder, row)
} match {
case e: ArithmeticException =>
assert(e.getMessage.contains("cannot be represented as Decimal"))
@@ -184,7 +196,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
val encoder = RowEncoder(schema).resolveAndBind()
- assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
+ assert(roundTrip(encoder, row).get(0) == null)
}
}
@@ -237,8 +249,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue),
Array(11.1111, 123456.7890123, Double.MaxValue)
)
- val row = encoder.toRow(Row.fromSeq(input))
- val convertedBack = encoder.fromRow(row)
+ val convertedBack = roundTrip(encoder, Row.fromSeq(input))
input.zipWithIndex.map { case (array, index) =>
assert(convertedBack.getSeq(index) === array)
}
@@ -254,8 +265,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
Array(1, 2, null),
Array(Array("abc", null), null),
Array(Seq(Array(0L, null), null), null))
- val row = encoder.toRow(input)
- val convertedBack = encoder.fromRow(row)
+ val convertedBack = roundTrip(encoder, input)
assert(convertedBack.getSeq(0) == Seq(1, 2, null))
assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null))
assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
@@ -264,7 +274,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
test("RowEncoder should throw RuntimeException if input row object is null")
{
val schema = new StructType().add("int", IntegerType)
val encoder = RowEncoder(schema)
- val e = intercept[RuntimeException](encoder.toRow(null))
+ val e = intercept[RuntimeException](toRow(encoder, null))
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
assert(e.getMessage.contains("top level Product or row object"))
}
@@ -273,14 +283,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val e1 = intercept[RuntimeException] {
val schema = new StructType().add("a", IntegerType)
val encoder = RowEncoder(schema)
- encoder.toRow(Row(1.toShort))
+ toRow(encoder, Row(1.toShort))
}
assert(e1.getMessage.contains("java.lang.Short is not a valid external
type"))
val e2 = intercept[RuntimeException] {
val schema = new StructType().add("a", StringType)
val encoder = RowEncoder(schema)
- encoder.toRow(Row(1))
+ toRow(encoder, Row(1))
}
assert(e2.getMessage.contains("java.lang.Integer is not a valid external
type"))
@@ -288,14 +298,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val schema = new StructType().add("a",
new StructType().add("b", IntegerType).add("c", StringType))
val encoder = RowEncoder(schema)
- encoder.toRow(Row(1 -> "a"))
+ toRow(encoder, Row(1 -> "a"))
}
assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
val e4 = intercept[RuntimeException] {
val schema = new StructType().add("a", ArrayType(TimestampType))
val encoder = RowEncoder(schema)
- encoder.toRow(Row(Array("a")))
+ toRow(encoder, Row(Array("a")))
}
assert(e4.getMessage.contains("java.lang.String is not a valid external
type"))
}
@@ -313,9 +323,9 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val schema = new StructType().add("t", TimestampType)
val encoder = RowEncoder(schema).resolveAndBind()
val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
- val row = encoder.toRow(Row(instant))
+ val row = toRow(encoder, Row(instant))
assert(row.getLong(0) === DateTimeUtils.instantToMicros(instant))
- val readback = encoder.fromRow(row)
+ val readback = fromRow(encoder, row)
assert(readback.get(0) === instant)
}
}
@@ -325,9 +335,9 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val schema = new StructType().add("d", DateType)
val encoder = RowEncoder(schema).resolveAndBind()
val localDate = java.time.LocalDate.parse("2019-02-27")
- val row = encoder.toRow(Row(localDate))
+ val row = toRow(encoder, Row(localDate))
assert(row.getLong(0) === DateTimeUtils.localDateToDays(localDate))
- val readback = encoder.fromRow(row)
+ val readback = fromRow(encoder, row)
assert(readback.get(0).equals(localDate))
}
}
@@ -374,8 +384,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
try {
for (_ <- 1 to 5) {
input = inputGenerator.apply().asInstanceOf[Row]
- val row = encoder.toRow(input)
- val convertedBack = encoder.fromRow(row)
+ val convertedBack = roundTrip(encoder, input)
assert(input == convertedBack)
}
} catch {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 68da1fa..af6e5a3 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -699,11 +699,11 @@ class HashExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable =
false).get
- val encoder = RowEncoder(inputSchema)
+ val toRow = RowEncoder(inputSchema).createSerializer()
val seed = scala.util.Random.nextInt()
test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") {
for (_ <- 1 to 10) {
- val input =
encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
+ val input =
toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
val literals =
input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
case (value, dt) => Literal.create(value, dt)
}
@@ -717,7 +717,7 @@ class HashExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
val longSeed = Math.abs(seed).toLong + Integer.MAX_VALUE.toLong
test(s"SPARK-30633: xxHash64 with long seed: ${inputSchema.simpleString}")
{
for (_ <- 1 to 10) {
- val input =
encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
+ val input =
toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
val literals =
input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
case (value, dt) => Literal.create(value, dt)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index ef7764d..c401493 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -445,8 +445,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
testTypes.foreach { dt =>
genSchema(dt).map { schema =>
val row = RandomDataGenerator.randomRow(random, schema)
- val rowConverter = RowEncoder(schema)
- val internalRow = rowConverter.toRow(row)
+ val toRow = RowEncoder(schema).createSerializer()
+ val internalRow = toRow(row)
val lambda = LambdaVariable("dummy", schema(0).dataType,
schema(0).nullable, id = 0)
checkEvaluationWithoutCodegen(lambda, internalRow.get(0,
schema(0).dataType), internalRow)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index fb1ea7b..dd67a61 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -60,8 +60,8 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
test("rows with all empty int arrays") {
val schema = StructType(Seq(
StructField("f1", ArrayType(IntegerType)), StructField("f2",
ArrayType(IntegerType))))
- val emptyIntArray =
-
ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0)
+ val toRow =
ExpressionEncoder[Array[Int]]().resolveAndBind().createSerializer()
+ val emptyIntArray = toRow(Array.emptyIntArray).getArray(0)
val row: UnsafeRow = UnsafeProjection.create(schema).apply(
InternalRow(emptyIntArray, emptyIntArray))
testConcat(schema, row, schema, row)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
index da71e3a..1e43035 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
@@ -73,8 +73,8 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite {
arrayTypes.foreach { dt =>
val schema = StructType(StructField("col_1", dt, nullable = false) ::
Nil)
val row = RandomDataGenerator.randomRow(random, schema)
- val rowConverter = RowEncoder(schema)
- val internalRow = rowConverter.toRow(row)
+ val toRow = RowEncoder(schema).createSerializer()
+ val internalRow = toRow(row)
val unsafeRowConverter = UnsafeProjection.create(schema)
val safeRowConverter = SafeProjection.create(schema)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index e7b1c08..6d8ef68 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
import java.time.{ZoneId, ZoneOffset}
+import scala.reflect.runtime.universe.TypeTag
+
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.Row
@@ -70,75 +72,55 @@ class UnsafeArraySuite extends SparkFunSuite {
arrayData
}
+ /**
+ * Serializes `array` with an `ExpressionEncoder[Array[T]]` serializer and
+ * returns the array stored in field 0 of the resulting row, after checking
+ * that it is an `UnsafeArrayData` with the same number of elements as `array`.
+ */
+ private def toUnsafeArray[T : TypeTag](array: Array[T]): ArrayData = {
+ val converted =
ExpressionEncoder[Array[T]].createSerializer().apply(array).getArray(0)
+ // `T` is erased at runtime, so `isInstanceOf[T]` would accept any value;
+ // check the concrete unsafe representation the callers rely on instead.
+ assert(converted.isInstanceOf[UnsafeArrayData])
+ assert(converted.numElements == array.length)
+ converted
+ }
+
test("read array") {
- val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
- toRow(booleanArray).getArray(0)
- assert(unsafeBoolean.isInstanceOf[UnsafeArrayData])
- assert(unsafeBoolean.numElements == booleanArray.length)
+ val unsafeBoolean = toUnsafeArray(booleanArray)
booleanArray.zipWithIndex.map { case (e, i) =>
assert(unsafeBoolean.getBoolean(i) == e)
}
- val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind().
- toRow(shortArray).getArray(0)
- assert(unsafeShort.isInstanceOf[UnsafeArrayData])
- assert(unsafeShort.numElements == shortArray.length)
+ val unsafeShort = toUnsafeArray(shortArray)
shortArray.zipWithIndex.map { case (e, i) =>
assert(unsafeShort.getShort(i) == e)
}
- val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind().
- toRow(intArray).getArray(0)
- assert(unsafeInt.isInstanceOf[UnsafeArrayData])
- assert(unsafeInt.numElements == intArray.length)
+ val unsafeInt = toUnsafeArray(intArray)
intArray.zipWithIndex.map { case (e, i) =>
assert(unsafeInt.getInt(i) == e)
}
- val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind().
- toRow(longArray).getArray(0)
- assert(unsafeLong.isInstanceOf[UnsafeArrayData])
- assert(unsafeLong.numElements == longArray.length)
+ val unsafeLong = toUnsafeArray(longArray)
longArray.zipWithIndex.map { case (e, i) =>
assert(unsafeLong.getLong(i) == e)
}
- val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind().
- toRow(floatArray).getArray(0)
- assert(unsafeFloat.isInstanceOf[UnsafeArrayData])
- assert(unsafeFloat.numElements == floatArray.length)
+ val unsafeFloat = toUnsafeArray(floatArray)
floatArray.zipWithIndex.map { case (e, i) =>
assert(unsafeFloat.getFloat(i) == e)
}
- val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind().
- toRow(doubleArray).getArray(0)
- assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
- assert(unsafeDouble.numElements == doubleArray.length)
+ val unsafeDouble = toUnsafeArray(doubleArray)
doubleArray.zipWithIndex.map { case (e, i) =>
assert(unsafeDouble.getDouble(i) == e)
}
- val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind().
- toRow(stringArray).getArray(0)
- assert(unsafeString.isInstanceOf[UnsafeArrayData])
- assert(unsafeString.numElements == stringArray.length)
+ val unsafeString = toUnsafeArray(stringArray)
stringArray.zipWithIndex.map { case (e, i) =>
assert(unsafeString.getUTF8String(i).toString().equals(e))
}
- val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind().
- toRow(dateArray).getArray(0)
- assert(unsafeDate.isInstanceOf[UnsafeArrayData])
- assert(unsafeDate.numElements == dateArray.length)
+ val unsafeDate = toUnsafeArray(dateArray)
dateArray.zipWithIndex.map { case (e, i) =>
assert(unsafeDate.get(i, DateType).asInstanceOf[Int] == e)
}
- val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind().
- toRow(timestampArray).getArray(0)
- assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData])
- assert(unsafeTimestamp.numElements == timestampArray.length)
+ val unsafeTimestamp = toUnsafeArray(timestampArray)
timestampArray.zipWithIndex.map { case (e, i) =>
assert(unsafeTimestamp.get(i, TimestampType).asInstanceOf[Long] == e)
}
@@ -149,7 +131,7 @@ class UnsafeArraySuite extends SparkFunSuite {
"array", ArrayType(DecimalType(decimal.precision, decimal.scale)))
val encoder = RowEncoder(schema).resolveAndBind()
val externalRow = Row(decimalArray)
- val ir = encoder.toRow(externalRow)
+ val ir = encoder.createSerializer().apply(externalRow)
val unsafeDecimal = ir.getArray(0)
assert(unsafeDecimal.isInstanceOf[UnsafeArrayData])
@@ -162,7 +144,7 @@ class UnsafeArraySuite extends SparkFunSuite {
val schema = new StructType().add("array", ArrayType(CalendarIntervalType))
val encoder = RowEncoder(schema).resolveAndBind()
val externalRow = Row(calenderintervalArray)
- val ir = encoder.toRow(externalRow)
+ val ir = encoder.createSerializer().apply(externalRow)
val unsafeCalendar = ir.getArray(0)
assert(unsafeCalendar.isInstanceOf[UnsafeArrayData])
assert(unsafeCalendar.numElements == calenderintervalArray.length)
@@ -170,10 +152,7 @@ class UnsafeArraySuite extends SparkFunSuite {
assert(unsafeCalendar.getInterval(i) == e)
}
- val unsafeMultiDimInt =
ExpressionEncoder[Array[Array[Int]]].resolveAndBind().
- toRow(intMultiDimArray).getArray(0)
- assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData])
- assert(unsafeMultiDimInt.numElements == intMultiDimArray.length)
+ val unsafeMultiDimInt = toUnsafeArray(intMultiDimArray)
intMultiDimArray.zipWithIndex.map { case (a, j) =>
val u = unsafeMultiDimInt.getArray(j)
assert(u.isInstanceOf[UnsafeArrayData])
@@ -183,10 +162,7 @@ class UnsafeArraySuite extends SparkFunSuite {
}
}
- val unsafeMultiDimDouble =
ExpressionEncoder[Array[Array[Double]]].resolveAndBind().
- toRow(doubleMultiDimArray).getArray(0)
- assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
- assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length)
+ val unsafeMultiDimDouble = toUnsafeArray(doubleMultiDimArray)
doubleMultiDimArray.zipWithIndex.map { case (a, j) =>
val u = unsafeMultiDimDouble.getArray(j)
assert(u.isInstanceOf[UnsafeArrayData])
@@ -216,11 +192,9 @@ class UnsafeArraySuite extends SparkFunSuite {
}
test("to primitive array") {
- val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
-
assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray))
+ assert(toUnsafeArray(intArray).toIntArray().sameElements(intArray))
- val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
-
assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
+
assert(toUnsafeArray(doubleArray).toDoubleArray().sameElements(doubleArray))
}
test("unsafe java serialization") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5c3a82a..c17535a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2965,9 +2965,8 @@ class Dataset[T] private[sql](
*/
def toLocalIterator(): java.util.Iterator[T] = {
withAction("toLocalIterator", queryExecution) { plan =>
- // `ExpressionEncoder` is not thread-safe, here we create a new encoder.
- val enc = resolvedEnc.copy()
- plan.executeToIterator().map(enc.fromRow).asJava
+ val fromRow = resolvedEnc.createDeserializer()
+ plan.executeToIterator().map(fromRow).asJava
}
}
@@ -3387,9 +3386,10 @@ class Dataset[T] private[sql](
new JSONOptions(Map.empty[String, String], sessionLocalTimeZone))
new Iterator[String] {
+ private val toRow = exprEnc.createSerializer()
override def hasNext: Boolean = iter.hasNext
override def next(): String = {
- gen.write(exprEnc.toRow(iter.next()))
+ gen.write(toRow(iter.next()))
gen.flush()
val json = writer.toString
@@ -3621,9 +3621,8 @@ class Dataset[T] private[sql](
* Collect all elements from a spark plan.
*/
private def collectFromPlan(plan: SparkPlan): Array[T] = {
- // `ExpressionEncoder` is not thread-safe, here we create a new encoder.
- val enc = resolvedEnc.copy()
- plan.executeCollect().map(enc.fromRow)
+ val fromRow = resolvedEnc.createDeserializer()
+ plan.executeCollect().map(fromRow)
}
private def sortInternal(global: Boolean, sortExprs: Seq[Column]):
Dataset[T] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index bca841c..731aae8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -345,7 +345,8 @@ class SparkSession private(
// TODO: use MutableProjection when rowRDD is another DataFrame and the
applied
// schema differs from the existing schema on any field data type.
val encoder = RowEncoder(schema)
- val catalystRows = rowRDD.map(encoder.toRow)
+ val toRow = encoder.createSerializer()
+ val catalystRows = rowRDD.map(toRow)
internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema)
}
@@ -459,10 +460,10 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
- // `ExpressionEncoder` is not thread-safe, here we create a new encoder.
- val enc = encoderFor[T].copy()
+ val enc = encoderFor[T]
+ val toRow = enc.createSerializer()
val attributes = enc.schema.toAttributes
- val encoded = data.map(d => enc.toRow(d).copy())
+ val encoded = data.map(d => toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
Dataset[T](self, plan)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index bd2684d..12a1a1e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -656,7 +656,8 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
case MemoryPlan(sink, output) =>
val encoder = RowEncoder(StructType.fromAttributes(output))
- LocalTableScanExec(output, sink.allData.map(r =>
encoder.toRow(r).copy())) :: Nil
+ val toRow = encoder.createSerializer()
+ LocalTableScanExec(output, sink.allData.map(r => toRow(r).copy())) ::
Nil
case logical.Distinct(child) =>
throw new IllegalStateException(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index dfae5c0..544b90a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -469,14 +469,17 @@ case class ScalaAggregator[IN, BUF, OUT](
with ImplicitCastInputTypes
with Logging {
- private[this] lazy val inputEncoder = inputEncoderNR.resolveAndBind()
+ private[this] lazy val inputDeserializer =
inputEncoderNR.resolveAndBind().createDeserializer()
private[this] lazy val bufferEncoder =
agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
+ private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
+ private[this] lazy val bufferDeserializer =
bufferEncoder.createDeserializer()
private[this] lazy val outputEncoder =
agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
+ private[this] lazy val outputSerializer = outputEncoder.createSerializer()
def dataType: DataType = outputEncoder.objSerializer.dataType
- def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)
+ def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType)
override lazy val deterministic: Boolean = isDeterministic
@@ -491,23 +494,23 @@ case class ScalaAggregator[IN, BUF, OUT](
def createAggregationBuffer(): BUF = agg.zero
def update(buffer: BUF, input: InternalRow): BUF =
- agg.reduce(buffer, inputEncoder.fromRow(inputProjection(input)))
+ agg.reduce(buffer, inputDeserializer(inputProjection(input)))
def merge(buffer: BUF, input: BUF): BUF = agg.merge(buffer, input)
def eval(buffer: BUF): Any = {
- val row = outputEncoder.toRow(agg.finish(buffer))
+ val row = outputSerializer(agg.finish(buffer))
if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
}
private[this] lazy val bufferRow = new
UnsafeRow(bufferEncoder.namedExpressions.length)
def serialize(agg: BUF): Array[Byte] =
- bufferEncoder.toRow(agg).asInstanceOf[UnsafeRow].getBytes()
+ bufferSerializer(agg).asInstanceOf[UnsafeRow].getBytes()
def deserialize(storageFormat: Array[Byte]): BUF = {
bufferRow.pointTo(storageFormat, storageFormat.length)
- bufferEncoder.fromRow(bufferRow)
+ bufferDeserializer(bufferRow)
}
override def toString: String = s"""${nodeName}(${children.mkString(",")})"""
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index faf3760..a58038d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -637,9 +637,9 @@ object DataSourceStrategy {
output: Seq[Attribute],
rdd: RDD[Row]): RDD[InternalRow] = {
if (relation.needConversion) {
- val converters = RowEncoder(StructType.fromAttributes(output))
+ val toRow =
RowEncoder(StructType.fromAttributes(output)).createSerializer()
rdd.mapPartitions { iterator =>
- iterator.map(converters.toRow)
+ iterator.map(toRow)
}
} else {
rdd.asInstanceOf[RDD[InternalRow]]
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index c1e1aed..3f44312 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -331,9 +331,9 @@ object JdbcUtils extends Logging {
def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row]
= {
val inputMetrics =
Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new
InputMetrics)
- val encoder = RowEncoder(schema).resolveAndBind()
+ val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
val internalRows = resultSetToSparkInternalRows(resultSet, schema,
inputMetrics)
- internalRows.map(encoder.fromRow)
+ internalRows.map(fromRow)
}
private[spark] def resultSetToSparkInternalRows(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
index 64b98fb..b4a14c6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
@@ -34,7 +34,9 @@ case class DescribeNamespaceExec(
catalog: SupportsNamespaces,
namespace: Seq[String],
isExtended: Boolean) extends V2CommandExec {
- private val encoder =
RowEncoder(StructType.fromAttributes(output)).resolveAndBind()
+ private val toRow = {
+
RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
+ }
override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]()
@@ -57,6 +59,6 @@ case class DescribeNamespaceExec(
}
private def toCatalystRow(strs: String*): InternalRow = {
- encoder.toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
+ toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
index 9c28020..bc6bb17 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
@@ -31,7 +31,9 @@ case class DescribeTableExec(
table: Table,
isExtended: Boolean) extends V2CommandExec {
- private val encoder =
RowEncoder(StructType.fromAttributes(output)).resolveAndBind()
+ private val toRow = {
+
RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
+ }
override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]()
@@ -85,6 +87,6 @@ case class DescribeTableExec(
private def emptyRow(): InternalRow = toCatalystRow("", "", "")
private def toCatalystRow(strs: String*): InternalRow = {
- encoder.toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
+ toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
index 42b80a1..5f7b6f4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
@@ -31,10 +31,11 @@ case class ShowCurrentNamespaceExec(
catalogManager: CatalogManager)
extends V2CommandExec {
override protected def run(): Seq[InternalRow] = {
- val encoder = RowEncoder(schema).resolveAndBind()
- Seq(encoder
- .toRow(new GenericRowWithSchema(
- Array(catalogManager.currentCatalog.name,
catalogManager.currentNamespace.quoted), schema))
- .copy())
+ val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
+ val result = new GenericRowWithSchema(Array[Any](
+ catalogManager.currentCatalog.name,
+ catalogManager.currentNamespace.quoted),
+ schema)
+ Seq(toRow(result).copy())
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
index 6f96848..9188f4e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
@@ -44,13 +44,11 @@ case class ShowNamespacesExec(
}
val rows = new ArrayBuffer[InternalRow]()
- val encoder = RowEncoder(schema).resolveAndBind()
+ val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
namespaces.map(_.quoted).map { ns =>
if (pattern.map(StringUtils.filterPattern(Seq(ns),
_).nonEmpty).getOrElse(true)) {
- rows += encoder
- .toRow(new GenericRowWithSchema(Array(ns), schema))
- .copy()
+ rows += toRow(new GenericRowWithSchema(Array(ns), schema)).copy()
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
index 7905c35..0bcd7ea 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
@@ -32,17 +32,17 @@ case class ShowTablePropertiesExec(
override protected def run(): Seq[InternalRow] = {
import scala.collection.JavaConverters._
- val encoder = RowEncoder(schema).resolveAndBind()
+ val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
val properties = catalogTable.properties.asScala
propertyKey match {
case Some(p) =>
val propValue = properties
.getOrElse(p, s"Table ${catalogTable.name} does not have property:
$p")
- Seq(encoder.toRow(new GenericRowWithSchema(Array(p, propValue),
schema)).copy())
+ Seq(toRow(new GenericRowWithSchema(Array(p, propValue),
schema)).copy())
case None =>
properties.keys.map(k =>
- encoder.toRow(new GenericRowWithSchema(Array(k, properties(k)),
schema)).copy()).toSeq
+ toRow(new GenericRowWithSchema(Array(k, properties(k)),
schema)).copy()).toSeq
}
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
index c740e0d..820f5ae 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
@@ -37,17 +37,15 @@ case class ShowTablesExec(
pattern: Option[String]) extends V2CommandExec with LeafExecNode {
override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]()
- val encoder = RowEncoder(schema).resolveAndBind()
+ val toRow = RowEncoder(schema).resolveAndBind().createSerializer()
val tables = catalog.listTables(namespace.toArray)
tables.map { table =>
if (pattern.map(StringUtils.filterPattern(Seq(table.name()),
_).nonEmpty).getOrElse(true)) {
- rows += encoder
- .toRow(
- new GenericRowWithSchema(
- Array(table.namespace().quoted, table.name()),
- schema))
- .copy()
+ val result = new GenericRowWithSchema(
+ Array(table.namespace().quoted, table.name()),
+ schema)
+ rows += toRow(result).copy()
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index fc47c5e..368dfae 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -173,6 +173,7 @@ class TextSocketContinuousStream(
setDaemon(true)
override def run(): Unit = {
+ val toRow = encoder.createSerializer()
try {
while (true) {
val line = reader.readLine()
@@ -187,7 +188,7 @@ class TextSocketContinuousStream(
Timestamp.valueOf(
TextSocketReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
)
- buckets(currentOffset % numPartitions) += encoder.toRow(newData)
+ buckets(currentOffset % numPartitions) += toRow(newData)
.copy().asInstanceOf[UnsafeRow]
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index ea39c54..e5b9e68 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -57,6 +57,8 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext:
SQLContext) extends Spa
val encoder = encoderFor[A]
protected val attributes = encoder.schema.toAttributes
+ protected lazy val toRow: ExpressionEncoder.Serializer[A] =
encoder.createSerializer()
+
def toDS(): Dataset[A] = {
Dataset[A](sqlContext.sparkSession, logicalPlan)
}
@@ -176,7 +178,7 @@ case class MemoryStream[A : Encoder](
def addData(data: TraversableOnce[A]): Offset = {
val objects = data.toSeq
- val rows = objects.iterator.map(d =>
encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+ val rows = objects.iterator.map(d =>
toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
logDebug(s"Adding: $objects")
this.synchronized {
currentOffset = currentOffset + 1
@@ -243,7 +245,7 @@ case class MemoryStream[A : Encoder](
rows: Seq[UnsafeRow],
startOrdinal: Int,
endOrdinal: Int): String = {
- val fromRow = encoder.resolveAndBind().fromRow _
+ val fromRow = encoder.resolveAndBind().createDeserializer()
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${rows.map(row => fromRow(row)).mkString(", ")}"
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index f944693..d0cf602 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -60,7 +60,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int,
sqlContext: SQLContext, numPa
// Distribute data evenly among partition lists.
data.toSeq.zipWithIndex.map {
case (item, index) =>
- records(index % numPartitions) +=
encoder.toRow(item).copy().asInstanceOf[UnsafeRow]
+ records(index % numPartitions) +=
toRow(item).copy().asInstanceOf[UnsafeRow]
}
// The new target offset is the offset where all records in all partitions
have been processed.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
index 03c567c..6d5e7fd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -30,7 +30,8 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) =>
Unit, encoder: Expr
val resolvedEncoder = encoder.resolveAndBind(
data.logicalPlan.output,
data.sparkSession.sessionState.analyzer)
- val rdd =
data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
+ val fromRow = resolvedEncoder.createDeserializer()
+ val rdd = data.queryExecution.toRdd.map[T](fromRow)(encoder.clsTag)
val ds = data.sparkSession.createDataset(rdd)(encoder)
batchWriter(ds, batchId)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index 6e4f40a..ba54c85 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -73,7 +73,7 @@ case class ForeachWriterTable[T](
val boundEnc = enc.resolveAndBind(
inputSchema.toAttributes,
SparkSession.getActiveSession.get.sessionState.analyzer)
- boundEnc.fromRow
+ boundEnc.createDeserializer()
case Right(func) =>
func
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index 2b67407..deab42b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -172,10 +172,10 @@ class MemoryDataWriter(partition: Int, schema: StructType)
private val data = mutable.Buffer[Row]()
- private val encoder = RowEncoder(schema).resolveAndBind()
+ private val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
override def write(row: InternalRow): Unit = {
- data.append(encoder.fromRow(row))
+ data.append(fromRow(row))
}
override def commit(): MemoryWriterCommitMessage = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index d3ef03e..7ca9fbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -517,7 +517,8 @@ private[sql] object CatalogImpl {
data: Seq[T],
sparkSession: SparkSession): Dataset[T] = {
val enc = ExpressionEncoder[T]()
- val encoded = data.map(d => enc.toRow(d).copy())
+ val toRow = enc.createSerializer()
+ val encoded = data.map(d => toRow(d).copy())
val plan = new LocalRelation(enc.schema.toAttributes, encoded)
val queryExecution = sparkSession.sessionState.executePlan(plan)
new Dataset[T](queryExecution, enc)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
index 80340b5..4b2a2b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -28,14 +28,16 @@ class GroupedIteratorSuite extends SparkFunSuite {
test("basic") {
val schema = new StructType().add("i", IntegerType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()
+ val toRow = encoder.createSerializer()
+ val fromRow = encoder.createDeserializer()
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
- val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+ val grouped = GroupedIterator(input.iterator.map(toRow),
Seq('i.int.at(0)), schema.toAttributes)
val result = grouped.map {
case (key, data) =>
assert(key.numFields == 1)
- key.getInt(0) -> data.map(encoder.fromRow).toSeq
+ key.getInt(0) -> data.map(fromRow).toSeq
}.toSeq
assert(result ==
@@ -46,6 +48,8 @@ class GroupedIteratorSuite extends SparkFunSuite {
test("group by 2 columns") {
val schema = new StructType().add("i", IntegerType).add("l",
LongType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()
+ val toRow = encoder.createSerializer()
+ val fromRow = encoder.createDeserializer()
val input = Seq(
Row(1, 2L, "a"),
@@ -54,13 +58,13 @@ class GroupedIteratorSuite extends SparkFunSuite {
Row(2, 1L, "d"),
Row(3, 2L, "e"))
- val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+ val grouped = GroupedIterator(input.iterator.map(toRow),
Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
val result = grouped.map {
case (key, data) =>
assert(key.numFields == 2)
- (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
+ (key.getInt(0), key.getLong(1), data.map(fromRow).toSeq)
}.toSeq
assert(result ==
@@ -73,8 +77,9 @@ class GroupedIteratorSuite extends SparkFunSuite {
test("do nothing to the value iterator") {
val schema = new StructType().add("i", IntegerType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()
+ val toRow = encoder.createSerializer()
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
- val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+ val grouped = GroupedIterator(input.iterator.map(toRow),
Seq('i.int.at(0)), schema.toAttributes)
assert(grouped.length == 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala
index f582d84..9b0389c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala
@@ -40,13 +40,16 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
UnsafeArrayData.calculateHeaderPortionInBytes(count)
}
+ private lazy val intEncoder = ExpressionEncoder[Array[Int]]().resolveAndBind()
+
+ private lazy val doubleEncoder = ExpressionEncoder[Array[Double]]().resolveAndBind()
+
def readUnsafeArray(iters: Int): Unit = {
val count = 1024 * 1024 * 16
val rand = new Random(42)
-
+ val intArrayToRow = intEncoder.createSerializer()
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
- val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
- val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
+ val intUnsafeArray = intArrayToRow(intPrimitiveArray).getArray(0)
val readIntArray = { i: Int =>
var n = 0
while (n < iters) {
@@ -62,8 +65,8 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
}
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
- val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
- val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
+ val doubleArrayToRow = doubleEncoder.createSerializer()
+ val doubleUnsafeArray = doubleArrayToRow(doublePrimitiveArray).getArray(0)
val readDoubleArray = { i: Int =>
var n = 0
while (n < iters) {
@@ -90,12 +93,12 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
var intTotalLength: Int = 0
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
- val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
+ val intArrayToRow = intEncoder.createSerializer()
val writeIntArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
- len += intEncoder.toRow(intPrimitiveArray).getArray(0).numElements()
+ len += intArrayToRow(intPrimitiveArray).getArray(0).numElements()
n += 1
}
intTotalLength = len
@@ -103,12 +106,12 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
var doubleTotalLength: Int = 0
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
- val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
+ val doubleArrayToRow = doubleEncoder.createSerializer()
val writeDoubleArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
- len += doubleEncoder.toRow(doublePrimitiveArray).getArray(0).numElements()
+ len += doubleArrayToRow(doublePrimitiveArray).getArray(0).numElements()
n += 1
}
doubleTotalLength = len
@@ -126,8 +129,8 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
var intTotalLength: Int = 0
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
- val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
- val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
+ val intArrayToRow = intEncoder.createSerializer()
+ val intUnsafeArray = intArrayToRow(intPrimitiveArray).getArray(0)
val readIntArray = { i: Int =>
var len = 0
var n = 0
@@ -140,8 +143,8 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase {
var doubleTotalLength: Int = 0
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
- val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
- val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
+ val doubleArrayToRow = doubleEncoder.createSerializer()
+ val doubleUnsafeArray = doubleArrayToRow(doublePrimitiveArray).getArray(0)
val readDoubleArray = { i: Int =>
var len = 0
var n = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 2cd142f..8462916 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -304,7 +304,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
val partitionedFile = mock(classOf[PartitionedFile])
when(partitionedFile.filePath).thenReturn(file.getPath)
val encoder = RowEncoder(requiredSchema).resolveAndBind()
- encoder.fromRow(reader(partitionedFile).next())
+ encoder.createDeserializer().apply(reader(partitionedFile).next())
}
test("column pruning") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 6d5ad87..8d54395 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -144,16 +144,22 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
}
}
+ private def createToExternalRowConverter[A : Encoder](): A => Row = {
+ val encoder = encoderFor[A]
+ val toInternalRow = encoder.createSerializer()
+ val toExternalRow = RowEncoder(encoder.schema).resolveAndBind().createDeserializer()
+ toExternalRow.compose(toInternalRow)
+ }
+
/**
* Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
* This operation automatically blocks until all added data has been processed.
*/
object CheckAnswer {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
- val encoder = encoderFor[A]
- val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
+ val toExternalRow = createToExternalRowConverter[A]()
CheckAnswerRows(
- data.map(d => toExternalRow.fromRow(encoder.toRow(d))),
+ data.map(toExternalRow),
lastOnly = false,
isSorted = false)
}
@@ -174,10 +180,9 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
}
def apply[A: Encoder](isSorted: Boolean, data: A*): CheckAnswerRows = {
- val encoder = encoderFor[A]
- val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
+ val toExternalRow = createToExternalRowConverter[A]()
CheckAnswerRows(
- data.map(d => toExternalRow.fromRow(encoder.toRow(d))),
+ data.map(toExternalRow),
lastOnly = true,
isSorted = isSorted)
}
@@ -215,9 +220,8 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty)
def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = {
- val encoder = encoderFor[A]
- val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
- CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d))))
+ val toExternalRow = createToExternalRowConverter[A]()
+ CheckNewAnswerRows((data +: moreData).map(toExternalRow))
}
def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]