Github user davies commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15089#discussion_r79253457
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala ---
    @@ -17,18 +17,270 @@
     
     package org.apache.spark.sql.execution.python
     
    +import java.io._
    +
     import scala.collection.JavaConverters._
     import scala.collection.mutable.ArrayBuffer
     
     import net.razorvine.pickle.{Pickler, Unpickler}
     
    -import org.apache.spark.TaskContext
    +import org.apache.spark.{SparkEnv, TaskContext}
     import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
    +import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.catalyst.InternalRow
     import org.apache.spark.sql.catalyst.expressions._
     import org.apache.spark.sql.execution.SparkPlan
     import org.apache.spark.sql.types.{DataType, StructField, StructType}
    +import org.apache.spark.unsafe.Platform
    +import org.apache.spark.unsafe.memory.MemoryBlock
    +import org.apache.spark.util.{CompletionIterator, Utils}
    +
    +
    +/**
    + * A RowQueue is a FIFO queue for UnsafeRow.
    + *
    + * Rows are appended with `add` and consumed, in insertion order, with `remove`.
    + */
    +private[python] trait RowQueue {
    +  /**
    +   * Add a row to the end of the queue.
    +   *
    +   * @param row the row to append
    +   * @return true iff the row has been added; false means the queue did not accept the row
    +   *         and the caller has to put it somewhere else
    +   */
    +  def add(row: UnsafeRow): Boolean
    +
    +  /**
    +   * Retrieve and remove the first row.
    +   *
    +   * NOTE(review): the implementations visible in this file hand back one reused UnsafeRow
    +   * instance per queue, so a returned row is presumably only valid until the next call;
    +   * confirm before retaining it.
    +   *
    +   * @return the first row, or null if the queue is empty
    +   */
    +  def remove(): UnsafeRow
    +
    +  /**
    +   * Clean up all the resources held by the queue.
    +   */
    +  def close(): Unit
    +}
    +
    +/**
    + * A RowQueue that is based on an in-memory page. UnsafeRows are appended into it until it's
    + * full. Another thread could read from it at the same time (behind the writer).
    + *
    + * Every record is laid out as a 4-byte length followed by the bytes of the row. When a row
    + * does not fit but at least 4 bytes are left, a length of -1 is written as an end marker.
    + *
    + * NOTE(review): this class does no locking of its own; presumably the caller guarantees that
    + * the reader never gets ahead of the writer -- confirm against HybridRowQueue.
    + *
    + * @param page   the memory block that stores the records
    + * @param fields number of fields of the rows, used to build the row returned by `remove`
    + */
    +private[python] case class InMemoryRowQueue(page: MemoryBlock, fields: Int) extends RowQueue {
    +  private val base: AnyRef = page.getBaseObject
    +  private var last = page.getBaseOffset  // offset of the next record to write
    +  private var first = page.getBaseOffset  // offset of the next record to read
    +  private val resultRow = new UnsafeRow(fields)
    +
    +  /** Offset just past the last usable byte of the page. */
    +  private def pageEnd = page.getBaseOffset + page.size
    +
    +  /**
    +   * Copy `row` to the write position of the page.
    +   *
    +   * @return true if the row was stored, false if the page has no room left for it
    +   */
    +  def add(row: UnsafeRow): Boolean = {
    +    val rowSize = row.getSizeInBytes
    +    val end = pageEnd
    +    if (last + 4 + rowSize <= end) {
    +      // length prefix first, then the payload right behind it
    +      Platform.putInt(base, last, rowSize)
    +      Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, last + 4, rowSize)
    +      last += 4 + rowSize
    +      true
    +    } else {
    +      if (last + 4 <= end) {
    +        // enough room for a length field: leave the end marker for the reader
    +        Platform.putInt(base, last, -1)
    +      }
    +      false
    +    }
    +  }
    +
    +  /**
    +   * Point the shared result row at the record at the read position and advance past it.
    +   *
    +   * @return the shared result row, or null when fewer than 4 bytes are left in the page or
    +   *         the end marker (a negative length) is found
    +   */
    +  def remove(): UnsafeRow = {
    +    if (first + 4 > pageEnd) {
    +      null
    +    } else {
    +      val size = Platform.getInt(base, first)
    +      if (size < 0) {
    +        null
    +      } else {
    +        resultRow.pointTo(base, first + 4, size)
    +        first += 4 + size
    +        resultRow
    +      }
    +    }
    +  }
    +
    +  def close(): Unit = {
    +    // caller should override close() to free page
    +  }
    +}
    +
    +/**
    + * A RowQueue that is based on a file on disk. It stops accepting rows once someone starts
    + * to read from it.
    + *
    + * Every record is written as a 4-byte length followed by the bytes of the row. The first call
    + * to `remove` flushes and closes the output side and reopens the file for reading; from then
    + * on `add` returns false.
    + *
    + * `add`, `remove` and `close` all run under the lock of this object, so a writer thread can
    + * never write into a stream that a reader thread is closing, and the count of unread bytes is
    + * updated consistently.
    + *
    + * @param path   path of the file that stores the rows; it is deleted by `close`
    + * @param fields number of fields of the rows, used to build the row returned by `remove`
    + */
    +private[python] case class DiskRowQueue(path: String, fields: Int) extends RowQueue {
    +  private var fout = new FileOutputStream(path)
    +  private var out = new DataOutputStream(new BufferedOutputStream(fout))
    +  private var length = 0L  // bytes written to the file that have not been read back yet
    +
    +  private var fin: FileInputStream = _
    +  private var in: DataInputStream = _
    +  private val resultRow = new UnsafeRow(fields)
    +
    +  /**
    +   * Append `row` to the file.
    +   *
    +   * @return true if the row was written, false if reading has already started
    +   */
    +  def add(row: UnsafeRow): Boolean = synchronized {
    +    // The check and the write must happen under the same lock: otherwise remove() could
    +    // close the stream between them.
    +    if (out == null) {
    +      // Another thread is reading, stop writing this one
    +      false
    +    } else {
    +      out.writeInt(row.getSizeInBytes)
    +      out.write(row.getBytes)
    +      length += 4 + row.getSizeInBytes
    +      true
    +    }
    +  }
    +
    +  /**
    +   * Read the next row from the file, switching the queue to read mode on the first call.
    +   *
    +   * @return the shared result row backed by a freshly read byte array, or null if all the
    +   *         written rows have been read
    +   */
    +  def remove(): UnsafeRow = synchronized {
    +    if (out != null) {
    +      // first read: finish the output side and reopen the file for reading
    +      out.flush()
    +      out.close()
    +      out = null
    +      fout.close()
    +      fout = null
    +
    +      fin = new FileInputStream(path)
    +      in = new DataInputStream(new BufferedInputStream(fin))
    +    }
    +
    +    if (length > 0) {
    +      val size = in.readInt()
    +      assert(4 + size <= length, s"require ${4 + size} bytes for next row")
    +      val bytes = new Array[Byte](size)
    +      in.readFully(bytes)
    +      length -= 4 + size
    +      resultRow.pointTo(bytes, size)
    +      resultRow
    +    } else {
    +      null
    +    }
    +  }
    +
    +  /**
    +   * Close the underlying file streams and delete the file.
    +   *
    +   * The input stream is closed and the file is deleted even if closing the output stream
    +   * throws.
    +   */
    +  def close(): Unit = {
    +    try {
    +      synchronized {
    +        try {
    +          if (fout != null) {
    +            fout.close()
    +            fout = null
    +          }
    +        } finally {
    +          if (fin != null) {
    +            fin.close()
    +            fin = null
    +          }
    +        }
    +      }
    +    } finally {
    +      val file = new File(path)
    +      if (file.exists()) {
    +        file.delete()
    +      }
    +    }
    +  }
    +}
    +
    +/**
    + * A RowQueue that has a list of RowQueues, which could be in memory or on disk.
    + *
    + * HybridRowQueue could be safely appended in one thread, and pulled in another thread at the
    + * same time.
    + */
    +private[python] case class HybridRowQueue(
    +    memManager: TaskMemoryManager,
    +    dir: File, fields: Int)
    +  extends MemoryConsumer(memManager) with RowQueue {
    +
    +  // Each buffer should have at least one row
    +  private val queues = new java.util.LinkedList[RowQueue]()
    +
    +  private var writing: RowQueue = _
    +  private var reading: RowQueue = _
    +
    +  private[python] def numQueues(): Int = queues.size()
    +
    +  def spill(size: Long, trigger: MemoryConsumer): Long = {
    +    if (trigger == this) {
    +      // When it's triggered by itself, it should write upcoming rows into 
disk instead of copying
    +      // the rows already in the queue.
    +      return 0L
    +    }
    +    var released = 0L
    +    synchronized {
    +      // poll out all the buffers and add them back in the same order to 
make sure that the rows
    +      // are in correct order.
    +      val n = queues.size()
    +      var i = 0
    +      while (i < n) {
    +        i += 1
    +        val queue = queues.remove()
    +        val newQueue = if (i < n && queue.isInstanceOf[InMemoryRowQueue]) {
    --- End diff --
    
    done


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to