Github user davies commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15089#discussion_r81228310
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala ---
    @@ -0,0 +1,276 @@
    +/*
    +* Licensed to the Apache Software Foundation (ASF) under one or more
    +* contributor license agreements.  See the NOTICE file distributed with
    +* this work for additional information regarding copyright ownership.
    +* The ASF licenses this file to You under the Apache License, Version 2.0
    +* (the "License"); you may not use this file except in compliance with
    +* the License.  You may obtain a copy of the License at
    +*
    +*    http://www.apache.org/licenses/LICENSE-2.0
    +*
    +* Unless required by applicable law or agreed to in writing, software
    +* distributed under the License is distributed on an "AS IS" BASIS,
    +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +* See the License for the specific language governing permissions and
    +* limitations under the License.
    +*/
    +
    +package org.apache.spark.sql.execution.python
    +
    +import java.io._
    +
    +import com.google.common.io.Closeables
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
    +import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    +import org.apache.spark.unsafe.Platform
    +import org.apache.spark.unsafe.memory.MemoryBlock
    +
    +/**
    + * A RowQueue is a FIFO queue for UnsafeRow.
    + *
    + * This RowQueue is ONLY designed and used for Python UDF, which has only one writer and only one
    + * reader; the reader ALWAYS runs behind the writer.
    + */
    +private[python] trait RowQueue {
    +
    +  /**
    +   * Add a row to the end of the queue.
    +   *
    +   * @return true iff the row has been added into the queue; false if this queue cannot accept it
    +   */
    +  def add(row: UnsafeRow): Boolean
    +
    +  /**
    +   * Retrieve and remove the first row.
    +   *
    +   * It can only be called after add is called.
    +   *
    +   * @return the first row, or null if the queue is empty
    +   */
    +  def remove(): UnsafeRow
    +
    +  /**
    +   * Clean up all the resources held by this queue.
    +   */
    +  def close(): Unit
    +}
    +
    +/**
    + * A RowQueue that is based on an in-memory page. UnsafeRows are appended into it until it's full.
    + * Another thread may read from it at the same time (behind the writer).
    + *
    + * The format of UnsafeRow in page:
    + * [4 bytes to hold length of record (N)] [N bytes to hold record] [...]
    + * A length of -1 marks the end of the used region when the next record did not fit.
    + *
    + * `add` and `remove` both synchronize on this queue: the writer and the reader run on different
    + * threads, and without a happens-before edge the reader is not guaranteed to observe the length
    + * prefix, the record bytes, or the end-of-page marker written by `add`.
    + */
    +private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int)
    +  extends RowQueue {
    +  private val base: AnyRef = page.getBaseObject
    +  private val endOfPage: Long = page.getBaseOffset + page.size
    +  // the first location where a new row would be written
    +  private var writeOffset = page.getBaseOffset
    +  // points to the start of the next row to read
    +  private var readOffset = page.getBaseOffset
    +  // the same instance is re-pointed and returned by every remove() call
    +  private val resultRow = new UnsafeRow(numFields)
    +
    +  /**
    +   * Appends `row` to the page.
    +   *
    +   * @return false (after writing the end-of-page marker when it fits) if the page cannot hold the
    +   *         row; true once the row has been copied into the page
    +   */
    +  def add(row: UnsafeRow): Boolean = synchronized {
    +    val size = row.getSizeInBytes
    +    if (writeOffset + 4 + size > endOfPage) {
    +      // not enough space in this page to hold the new record
    +      if (writeOffset + 4 <= endOfPage) {
    +        // there is room for a length slot: store the special "end-of-page" length (-1) so the
    +        // reader stops here instead of interpreting leftover bytes as a record
    +        Platform.putInt(base, writeOffset, -1)
    +      }
    +      false
    +    } else {
    +      Platform.putInt(base, writeOffset, size)
    +      Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, writeOffset + 4, size)
    +      writeOffset += 4 + size
    +      true
    +    }
    +  }
    +
    +  /**
    +   * Returns the next row in the page, or null once the end of the page (or the end-of-page
    +   * marker) is reached. The returned row points into this page and is overwritten by the next
    +   * call.
    +   */
    +  def remove(): UnsafeRow = synchronized {
    +    // Memory beyond writeOffset has never been written; reading it would yield garbage lengths.
    +    assert(readOffset <= writeOffset, "reader should not go beyond writer")
    +    if (readOffset + 4 > endOfPage) {
    +      // no room left even for a length prefix: the page is exhausted
    +      null
    +    } else {
    +      val size = Platform.getInt(base, readOffset)
    +      if (size < 0) {
    +        // end-of-page marker written by add()
    +        null
    +      } else {
    +        resultRow.pointTo(base, readOffset + 4, size)
    +        readOffset += 4 + size
    +        resultRow
    +      }
    +    }
    +  }
    +}
    +
    +/**
    + * A RowQueue backed by a file on disk. Rows are appended to the file until the first call to
    + * remove(), which closes the writer and reopens the file for reading; after that, add() refuses
    + * new rows by returning false.
    + *
    + * On-disk layout per row: [4-byte length N] [N bytes of the row].
    + */
    +private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
    +  // writer side; becomes null once reading has started or the queue is closed
    +  private var out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))
    +  // bytes (length prefixes included) that have been written but not yet consumed
    +  private var unreadBytes = 0L
    +
    +  // reader side; opened by the first remove()
    +  private var in: DataInputStream = _
    +  // the same instance is re-pointed and returned by every remove() call
    +  private val resultRow = new UnsafeRow(fields)
    +
    +  def add(row: UnsafeRow): Boolean = synchronized {
    +    val writer = out
    +    if (writer == null) {
    +      // Another thread is reading, stop writing this one
    +      false
    +    } else {
    +      val numBytes = row.getSizeInBytes
    +      writer.writeInt(numBytes)
    +      writer.write(row.getBytes)
    +      unreadBytes += 4 + numBytes
    +      true
    +    }
    +  }
    +
    +  def remove(): UnsafeRow = synchronized {
    +    if (out != null) {
    +      // first read: flush and close the writer, then reopen the file for sequential reading
    +      out.close()
    +      out = null
    +      in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))
    +    }
    +
    +    if (unreadBytes <= 0) {
    +      null
    +    } else {
    +      val numBytes = in.readInt()
    +      val buffer = new Array[Byte](numBytes)
    +      in.readFully(buffer)
    +      unreadBytes -= 4 + numBytes
    +      resultRow.pointTo(buffer, numBytes)
    +      resultRow
    +    }
    +  }
    +
    +  def close(): Unit = synchronized {
    +    // second argument = swallow IOException: cleanup here is best-effort
    +    Closeables.close(out, true)
    +    out = null
    +    Closeables.close(in, true)
    +    in = null
    +    if (file.exists()) file.delete()
    +  }
    +}
    +
    +/**
    + * A RowQueue that has a list of RowQueues, which could be in memory or 
disk.
    + *
    + * HybridRowQueue could be safely appended in one thread, and pulled in 
another thread in the same
    + * time.
    + */
    +private[python] case class HybridRowQueue(
    +    memManager: TaskMemoryManager,
    +    tempDir: File,
    +    numFields: Int)
    +  extends MemoryConsumer(memManager) with RowQueue {
    +
    +  // Each buffer should have at least one row
    +  private var queues = new java.util.LinkedList[RowQueue]()
    +
    +  private var writing: RowQueue = _
    +  private var reading: RowQueue = _
    +
    +  // exposed for testing
    +  private[python] def numQueues(): Int = queues.size()
    +
    +  def spill(size: Long, trigger: MemoryConsumer): Long = {
    --- End diff --
    
    No, it will be used as a MemoryConsumer (called by other memory consumers).


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to