Github user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9241#discussion_r42953420
  
    --- Diff: core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java ---
    @@ -227,62 +238,147 @@ public BytesToBytesMap(
        */
       public int numElements() { return numElements; }
     
    -  public static final class BytesToBytesMapIterator implements Iterator<Location> {
    +  public final class BytesToBytesMapIterator implements Iterator<Location> {
     
    -    private final int numRecords;
    -    private final Iterator<MemoryBlock> dataPagesIterator;
    +    private int numRecords;
         private final Location loc;
     
         private MemoryBlock currentPage = null;
    -    private int currentRecordNumber = 0;
    +    private int recordsInPage = 0;
         private Object pageBaseObject;
         private long offsetInPage;
     
         // Whether this iterator is destructive. When it is true, it frees each page as it
         // moves onto the next one.
         private boolean destructive = false;
    -    private BytesToBytesMap bmap;
     
    -    private BytesToBytesMapIterator(
    -        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc,
    -        boolean destructive, BytesToBytesMap bmap) {
    +    private LinkedList<UnsafeSorterSpillWriter> spillWriters =
    +      new LinkedList<UnsafeSorterSpillWriter>();
    +    private UnsafeSorterSpillReader reader = null;
    +
    +    private BytesToBytesMapIterator(int numRecords, Location loc, boolean destructive) {
           this.numRecords = numRecords;
    -      this.dataPagesIterator = dataPagesIterator;
           this.loc = loc;
           this.destructive = destructive;
    -      this.bmap = bmap;
    -      if (dataPagesIterator.hasNext()) {
    -        advanceToNextPage();
    -      }
    +      destructiveIterator = this;
         }
     
         private void advanceToNextPage() {
    -      if (destructive && currentPage != null) {
    -        dataPagesIterator.remove();
    -        this.bmap.taskMemoryManager.freePage(currentPage);
    -        this.bmap.shuffleMemoryManager.release(currentPage.size());
    +      synchronized (this) {
    +        int nextIdx = dataPages.indexOf(currentPage) + 1;
    +        if (destructive && currentPage != null) {
    +          dataPages.remove(currentPage);
    +          taskMemoryManager.freePage(currentPage);
    +          shuffleMemoryManager.release(currentPage.size());
    +          nextIdx --;
    +        }
    +        if (dataPages.size() > nextIdx) {
    +          currentPage = dataPages.get(nextIdx);
    +          pageBaseObject = currentPage.getBaseObject();
    +          offsetInPage = currentPage.getBaseOffset();
    +          recordsInPage = Platform.getInt(pageBaseObject, offsetInPage);
    +          offsetInPage += 4;
    +        } else {
    +          currentPage = null;
    +          try {
    +            reader = spillWriters.removeFirst().getReader(blockManager);
    +            recordsInPage = -1;
    +          } catch (IOException e) {
    +            // Scala iterator does not handle exception
    +            Platform.throwException(e);
    +          }
    +        }
           }
    -      currentPage = dataPagesIterator.next();
    -      pageBaseObject = currentPage.getBaseObject();
    -      offsetInPage = currentPage.getBaseOffset();
         }
     
         @Override
         public boolean hasNext() {
    -      return currentRecordNumber != numRecords;
    +      return numRecords > 0;
         }
     
         @Override
         public Location next() {
    -      int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
    -      if (totalLength == END_OF_PAGE_MARKER) {
    +      if (recordsInPage == 0) {
             advanceToNextPage();
    -        totalLength = Platform.getInt(pageBaseObject, offsetInPage);
           }
    -      loc.with(currentPage, offsetInPage);
    -      offsetInPage += 4 + totalLength;
    -      currentRecordNumber++;
    -      return loc;
    +      numRecords --;
    +      if (currentPage != null) {
    +        int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
    +        loc.with(currentPage, offsetInPage);
    +        offsetInPage += 4 + totalLength;
    +        recordsInPage --;
    +        return loc;
    +      } else {
    +        assert(reader != null);
    +        if (!reader.hasNext()) {
    +          advanceToNextPage();
    +        }
    +        try {
    +          reader.loadNext();
    +        } catch (IOException e) {
    +          // Scala iterator does not handle exception
    +          Platform.throwException(e);
    +        }
    +        loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength());
    +        return loc;
    +      }
    +    }
    +
    +    public long spill(long numBytes) throws IOException {
    +      synchronized (this) {
    +        if (!destructive || dataPages.size() == 1) {
    +          return 0L;
    +        }
    +
    +        // TODO: use existing ShuffleWriteMetrics
    --- End diff --
    
    Let's chat about this later; I'm not sure whether or not the existing spillable
    collections have consistent semantics in terms of accounting for spills via
    shuffle write metrics. I spent a decent amount of time investigating this earlier
    in the year, so I'll see if I can find my notes from then.
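
    For context on the TODO under discussion: the new spill(long numBytes) path frees
    in-memory pages after writing them out, but nothing records the freed bytes on any
    metrics object yet. A minimal, self-contained sketch of the accounting pattern in
    question (the Page and SpillMetrics classes below are hypothetical stand-ins; Spark's
    actual ShuffleWriteMetrics API may differ) could look like this:

        import java.util.ArrayDeque;
        import java.util.Deque;

        // Hypothetical stand-ins for illustration only; not Spark's real classes.
        final class Page {
          final long sizeInBytes;
          Page(long sizeInBytes) { this.sizeInBytes = sizeInBytes; }
        }

        final class SpillMetrics {
          private long bytesSpilled = 0L;
          void incBytesSpilled(long n) { bytesSpilled += n; }
          long bytesSpilled() { return bytesSpilled; }
        }

        public class SpillAccountingSketch {
          // Free pages until at least numBytes have been released, recording every
          // released byte on the metrics object -- the accounting the TODO asks for.
          static long spill(Deque<Page> dataPages, long numBytes, SpillMetrics metrics) {
            long released = 0L;
            // Keep at least one page for the record currently being read, mirroring
            // the `dataPages.size() == 1` guard in the diff above.
            while (released < numBytes && dataPages.size() > 1) {
              Page page = dataPages.removeFirst();
              // ... a real implementation would write the page's records out
              // through an UnsafeSorterSpillWriter before freeing it ...
              released += page.sizeInBytes;
              metrics.incBytesSpilled(page.sizeInBytes);
            }
            return released;
          }

          public static void main(String[] args) {
            Deque<Page> dataPages = new ArrayDeque<>();
            for (int i = 0; i < 3; i++) {
              dataPages.addLast(new Page(1L << 20));  // three 1 MiB pages
            }
            SpillMetrics metrics = new SpillMetrics();
            long released = spill(dataPages, 1L << 20, metrics);
            // Prints: released=1048576 recorded=1048576 pagesLeft=2
            System.out.println("released=" + released
                + " recorded=" + metrics.bytesSpilled()
                + " pagesLeft=" + dataPages.size());
          }
        }

    The point is only the shape of the bookkeeping: every byte released by the spill
    path should be counted on exactly one metrics object, which is the consistency
    question raised above.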


