This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 62338ed6cd9 [SPARK-42626][CONNECT] Add Destructive Iterator for SparkResult 62338ed6cd9 is described below commit 62338ed6cd9fba8bb92ec11cea643077e4b69db4 Author: Tengfei Huang <tengfe...@gmail.com> AuthorDate: Mon Jun 5 21:30:02 2023 -0400 [SPARK-42626][CONNECT] Add Destructive Iterator for SparkResult ### What changes were proposed in this pull request? Add a destructive iterator to SparkResult and change `Dataset.toLocalIterator` to use the destructive iterator. With the destructive iterator, we will: 1. Close the `ColumnarBatch` once its data has been consumed; 2. Remove the `ColumnarBatch` from `SparkResult.batches`; ### Why are the changes needed? Instead of keeping everything in memory for the lifetime of the SparkResult object, clean it up as soon as we know we are done with it. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT added. Closes #40610 from ivoson/SPARK-42626. Authored-by: Tengfei Huang <tengfe...@gmail.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 3 +- .../spark/sql/connect/client/SparkResult.scala | 43 ++++++++++++------ .../org/apache/spark/sql/ClientE2ETestSuite.scala | 52 +++++++++++++++++++++- 3 files changed, 81 insertions(+), 17 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 7a680bde7d3..eba425ce127 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2768,8 +2768,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def toLocalIterator(): java.util.Iterator[T] = { - // TODO make this a destructive iterator.
- collectResult().iterator + collectResult().destructiveIterator } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 49db44bd855..86a7cf846f2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -46,7 +46,8 @@ private[sql] class SparkResult[T]( private[this] var numRecords: Int = 0 private[this] var structType: StructType = _ private[this] var boundEncoder: ExpressionEncoder[T] = _ - private[this] val batches = mutable.Buffer.empty[ColumnarBatch] + private[this] var nextBatchIndex: Int = 0 + private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch] private def createEncoder(schema: StructType): ExpressionEncoder[T] = { val agnosticEncoder = if (encoder == UnboundRowEncoder) { @@ -70,12 +71,12 @@ private[sql] class SparkResult[T]( val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator) try { val root = reader.getVectorSchemaRoot - if (batches.isEmpty) { - if (structType == null) { - // If the schema is not available yet, fallback to the schema from Arrow. - structType = ArrowUtils.fromArrowSchema(root.getSchema) - } - // TODO: create encoders that directly operate on arrow vectors. + if (structType == null) { + // If the schema is not available yet, fallback to the schema from Arrow. + structType = ArrowUtils.fromArrowSchema(root.getSchema) + } + // TODO: create encoders that directly operate on arrow vectors. 
+ if (boundEncoder == null) { boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes) } while (reader.loadNextBatch()) { @@ -85,7 +86,8 @@ private[sql] class SparkResult[T]( val vectors = root.getFieldVectors.asScala .map(v => new ArrowColumnVector(transferToNewVector(v))) .toArray[ColumnVector] - batches += new ColumnarBatch(vectors, rowCount) + idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount)) + nextBatchIndex += 1 numRecords += rowCount if (stopOnFirstNonEmptyResponse) { return true @@ -142,24 +144,39 @@ private[sql] class SparkResult[T]( /** * Returns an iterator over the contents of the result. */ - def iterator: java.util.Iterator[T] with AutoCloseable = { + def iterator: java.util.Iterator[T] with AutoCloseable = + buildIterator(destructive = false) + + /** + * Returns an destructive iterator over the contents of the result. + */ + def destructiveIterator: java.util.Iterator[T] with AutoCloseable = + buildIterator(destructive = true) + + private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = { new java.util.Iterator[T] with AutoCloseable { private[this] var batchIndex: Int = -1 private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator() private[this] var deserializer: Deserializer[T] = _ + override def hasNext: Boolean = { if (iterator.hasNext) { return true } + val nextBatchIndex = batchIndex + 1 - val hasNextBatch = if (nextBatchIndex == batches.size) { + if (destructive) { + idxToBatches.remove(batchIndex).foreach(_.close()) + } + + val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) { processResponses(stopOnFirstNonEmptyResponse = true) } else { true } if (hasNextBatch) { batchIndex = nextBatchIndex - iterator = batches(nextBatchIndex).rowIterator() + iterator = idxToBatches(nextBatchIndex).rowIterator() if (deserializer == null) { deserializer = boundEncoder.createDeserializer() } @@ -182,8 +199,8 @@ private[sql] class 
SparkResult[T]( * Close this result, freeing any underlying resources. */ override def close(): Unit = { - batches.foreach(_.close()) + idxToBatches.values.foreach(_.close()) } - override def cleaner: AutoCloseable = AutoCloseables(batches.toSeq) + override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 1a775f55ff5..bdef6b92ece 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -21,6 +21,7 @@ import java.nio.file.Files import java.util.Properties import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.util.{Failure, Success} @@ -30,21 +31,23 @@ import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalactic.TolerantNumerics +import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SPARK_VERSION, SparkException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ThreadUtils -class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { +class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { // Spark Result test("spark result schema") { @@ -890,6 +893,51 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { assert(message.contains("PARSE_SYNTAX_ERROR")) } + test("Dataset result destructive iterator") { + // Helper methods for accessing private field `idxToBatches` from SparkResult + val _idxToBatches = + PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches")) + + def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = { + val idxToBatches = result invokePrivate _idxToBatches() + + // Sort by key to get stable results. + idxToBatches.toSeq.sortBy(_._1).map(_._2) + } + + val df = spark + .range(0, 10, 1, 10) + .filter("id > 5 and id < 9") + + df.withResult { result => + try { + // build and verify the destructive iterator + val iterator = result.destructiveIterator + // batches is empty before traversing the result iterator + assert(getColumnarBatches(result).isEmpty) + var previousBatch: ColumnarBatch = null + val buffer = mutable.Buffer.empty[Long] + while (iterator.hasNext) { + // always having 1 batch, since a columnar batch will be removed and closed after + // its data got consumed. + val batches = getColumnarBatches(result) + assert(batches.size === 1) + assert(batches.head != previousBatch) + previousBatch = batches.head + + buffer.append(iterator.next()) + } + // Batches should be closed and removed after traversing all the records. 
+ assert(getColumnarBatches(result).isEmpty) + + val expectedResult = Seq(6L, 7L, 8L) + assert(buffer.size === 3 && expectedResult.forall(buffer.contains)) + } finally { + result.close() + } + } + } + test("SparkSession.createDataFrame - large data set") { val threshold = 1024 * 1024 withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org