Github user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7592#discussion_r35401208
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala ---
    @@ -149,31 +158,141 @@ private[joins] object HashedRelation {
       }
     }
     
    +/**
    + * An extended CompactBuffer that can grow and be updated in place.
    + */
    +private[joins] class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] {
    +  override def growToSize(newSize: Int): Unit = super.growToSize(newSize)
    +  override def update(i: Int, v: T): Unit = super.update(i, v)
    +}
     
     /**
      * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
      * sequence of values.
    - *
    - * TODO(davies): use BytesToBytesMap
      */
     private[joins] final class UnsafeHashedRelation(
         private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
       extends HashedRelation with Externalizable {
     
    -  def this() = this(null)  // Needed for serialization
    +  private[joins] def this() = this(null)  // Needed for serialization
    +
    +  // Use BytesToBytesMap in executor for better performance (it's created during deserialization)
    +  @transient private[this] var binaryMap: BytesToBytesMap = _
    +
    +  // A pool of compact buffers to reduce memory garbage
    +  @transient private[this] val bufferPool = new ThreadLocal[MutableCompactBuffer[UnsafeRow]]
     
    -  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
    +  override def get(key: InternalRow): Seq[InternalRow] = {
         val unsafeKey = key.asInstanceOf[UnsafeRow]
    -    // Thanks to type eraser
    -    hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
    +
    +    if (binaryMap != null) {
    +      // Used in Broadcast join
    +      val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
    +        unsafeKey.getSizeInBytes)
    +      if (loc.isDefined) {
    +        // thread-local buffer
    +        var buffer = bufferPool.get()
    +        if (buffer == null) {
    +          buffer = new MutableCompactBuffer[UnsafeRow]
    +          bufferPool.set(buffer)
    +        }
    +
    +        val base = loc.getValueAddress.getBaseObject
    +        var offset = loc.getValueAddress.getBaseOffset
    +        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
    +        var i = 0
    +        while (offset < last) {
    +          val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
    +          val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
    +          offset += 8
    +
    +          // try to re-use the UnsafeRow in buffer, to reduce garbage
    +          buffer.growToSize(i + 1)
    +          if (buffer(i) == null) {
    +            buffer(i) = new UnsafeRow
    +          }
    +          buffer(i).pointTo(base, offset, numFields, sizeInBytes, null)
    +          i += 1
    +          offset += sizeInBytes
    +        }
    +        buffer
    +      } else {
    +        null
    +      }
    +
    +    } else {
    +      // Use the JavaHashMap in Local mode or ShuffleHashJoin
    +      hashTable.get(unsafeKey)
    +    }
       }
     
       override def writeExternal(out: ObjectOutput): Unit = {
    -    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
    +    out.writeInt(hashTable.size())
    +
    +    val iter = hashTable.entrySet().iterator()
    +    while (iter.hasNext) {
    +      val entry = iter.next()
    +      val key = entry.getKey
    +      val values = entry.getValue
    +
    +      // write all the values as a single byte array
    +      var totalSize = 0L
    +      var i = 0
    +      while (i < values.size) {
    +        totalSize += values(i).getSizeInBytes + 4 + 4
    +        i += 1
    +      }
    +      assert(totalSize < Integer.MAX_VALUE, "values are too big")
    +
    +      // [key size] [values size] [key bytes] [values bytes]
    +      out.writeInt(key.getSizeInBytes)
    +      out.writeInt(totalSize.toInt)
    +      out.write(key.getBytes)
    +      i = 0
    +      while (i < values.size) {
    +        // [num of fields] [num of bytes] [row bytes]
    +        // write the integers in native order, so they can be read by UNSAFE.getInt()
    +        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
    +          out.writeInt(values(i).length())
    +          out.writeInt(values(i).getSizeInBytes)
    +        } else {
    +          out.writeInt(Integer.reverseBytes(values(i).length()))
    +          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
    +        }
    +        out.write(values(i).getBytes)
    --- End diff ---
    
    Can you add a new method to UnsafeRow that writes to an ObjectOutput? It should be similar to Unsafe.writeToStream.
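
    For illustration only, a minimal sketch of what such a helper might look like if it lived on UnsafeRow (the name writeToObjectOutput is a placeholder, not an existing Spark API; it simply factors out the inline loop above, writing [num of fields] [num of bytes] [row bytes] with the two ints in native byte order so they can be read back via UNSAFE.getInt()):

        import java.io.ObjectOutput
        import java.nio.ByteOrder

        // Hypothetical method to be added to UnsafeRow; it relies only on
        // length(), getSizeInBytes and getBytes, which the diff above already uses.
        def writeToObjectOutput(out: ObjectOutput): Unit = {
          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
            // DataOutput.writeInt is big-endian, which already matches native order here
            out.writeInt(length())
            out.writeInt(getSizeInBytes)
          } else {
            // reverse the bytes so the stream holds the ints in native (little-endian) order
            out.writeInt(Integer.reverseBytes(length()))
            out.writeInt(Integer.reverseBytes(getSizeInBytes))
          }
          out.write(getBytes)
        }

    The serialization loop in writeExternal would then reduce to values(i).writeToObjectOutput(out).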
     

