Repository: spark
Updated Branches:
  refs/heads/branch-2.0 901ab0694 -> 56a842635


[SPARK-15382][SQL] Fix a bug in sampling with replacement

## What changes were proposed in this pull request?
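This PR fixes the following bug in sampling with replacement. When the sampler emits the same input row more than once, `monotonically_increasing_id` assigns duplicate IDs (note the repeated `1` below), because the sampled rows were not copied before being passed downstream: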
```
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
df.sample(true, 2.0).withColumn("c", monotonically_increasing_id).select($"c").show
+---+
|  c|
+---+
|  0|
|  1|
|  1|
|  1|
|  2|
+---+
```

## How was this patch tested?
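Added a test in `DataFrameSuite` that samples with replacement and asserts that every value produced by `monotonically_increasing_id` is distinct.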

Author: Takeshi YAMAMURO <[email protected]>

Closes #14800 from maropu/FixSampleBug.

(cherry picked from commit cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9)
Signed-off-by: Sean Owen <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56a84263
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56a84263
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56a84263

Branch: refs/heads/branch-2.0
Commit: 56a8426355b494bc085b649ae6c8245f0a039e3a
Parents: 901ab06
Author: Takeshi YAMAMURO <[email protected]>
Authored: Sat Aug 27 08:42:41 2016 +0100
Committer: Sean Owen <[email protected]>
Committed: Sat Aug 27 08:42:51 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/sql/execution/basicPhysicalOperators.scala   | 1 +
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  | 7 +++++++
 2 files changed, 8 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56a84263/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 90bf817..a544371 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -260,6 +260,7 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
+      ctx.copyResult = true
       ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
         s"$initSampler();")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56a84263/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 55edbe2..da5c538 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1558,4 +1558,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(rdd, StructType(schemas), false)
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }
+
+  test("copy results for sampling with replacement") {
+    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+    val sampleDf = df.sample(true, 2.00)
+    val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
+    assert(d.size == d.distinct.size)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to