Github user ash211 commented on the pull request:
https://github.com/apache/spark/pull/2117#issuecomment-53358527
The reason it's OOM-prone is that if the result of take() on the first
partition has size 0, it then runs take(n) on all the remaining partitions *and
brings them back to the driver*. It then takes from that set of per-partition
take() results until it fills the N requested rows.
That happens here:
```
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p,
  allowLocal = true)
res.foreach(buf ++= _.take(num - buf.size))
```
Between those two lines, the contents of allPartitions.map(_.take(left))
are in the driver's memory all at once.
In my situation I had 10k partitions, the first one was empty, and I did a
take(300). So after the first partition returned 0 results, the next wave pulled
up to 300 results from each of the remaining 9,999 partitions, which OOM'd the driver.
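For concreteness, here's a rough reproduction of that shape of job (the RDD
construction and sizes are illustrative, not the original code): with 9,999
remaining partitions capped at 300 rows each, the second wave can put roughly
3 million rows on the driver even though only 300 are kept.
```
// Illustrative only: an RDD with 10,000 partitions whose first partition is empty.
val rdd = sc.parallelize(1 to 10000000, 10000)
  .mapPartitionsWithIndex((i, it) => if (i == 0) Iterator.empty else it)

// With the existing logic, wave 1 scans partition 0 and gets 0 rows, so wave 2 runs
// take(300) on all 9,999 remaining partitions: up to 9,999 * 300 ≈ 3M rows come
// back to the driver even though only 300 are kept.
rdd.take(300)
```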
This patch isn't bulletproof, in the sense that it relies on the heuristic
that you'll reach a doubling wave that contains enough rows to fulfill the take()
before the volume of data returned to the driver causes an OOM.
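For reference, a rough sketch of that doubling scan (not the exact code in this
patch, and takeSketch is just an illustrative name): the driver only ever holds
the results of the partitions scanned so far, and the wave size grows until
enough rows have been collected.
```
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Rough sketch of a take(num) that scans partitions in exponentially growing waves.
def takeSketch[T: ClassTag](rdd: RDD[T], num: Int): Array[T] = {
  val buf = new ArrayBuffer[T]
  val totalParts = rdd.partitions.length
  var partsScanned = 0
  var numPartsToTry = 1
  while (buf.size < num && partsScanned < totalParts) {
    val left = num - buf.size
    val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
    // Only this wave's partitions are materialized on the driver.
    val res = rdd.sparkContext.runJob(rdd, (it: Iterator[T]) => it.take(left).toArray, p)
    res.foreach(buf ++= _.take(num - buf.size))
    partsScanned += p.size
    numPartsToTry *= 2  // double the wave size for the next round
  }
  buf.toArray.take(num)
}
```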
An alternative approach would be to evaluate the .take(n) across the
cluster instead and use the new .toLocalIterator call, which didn't exist when
this code was first written. That would address your observation that 2
waves of computation ought to be faster than log(n) waves across the cluster.
I'm thinking something more like:
```
rdd.mapPartitions(_.take(n)).toLocalIterator.take(n)
```
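The tradeoff is that toLocalIterator fetches one partition per job, so the driver
holds at most one partition's capped results at a time, at the cost of a job per
partition consumed rather than a handful of waves. A slightly fuller, hedged
sketch of the same idea (takeStreaming is just a name I'm making up here, not
anything in the API):
```
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Sketch of the alternative: cap each partition at n rows on the executors, then
// stream partitions back one at a time and stop once n rows have been read.
def takeStreaming[T: ClassTag](rdd: RDD[T], n: Int): Array[T] =
  rdd.mapPartitions(_.take(n)).toLocalIterator.take(n).toArray
```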