Github user mgaido91 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19811#discussion_r157097867
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
---
@@ -279,29 +279,29 @@ case class SampleExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row:
ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- val sampler = ctx.freshName("sampler")
if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
- val initSamplerFuncName = ctx.addNewFunction(initSampler,
- s"""
- | private void $initSampler() {
- | $sampler = new $samplerClass<UnsafeRow>($upperBound -
$lowerBound, false);
- | java.util.Random random = new java.util.Random(${seed}L);
- | long randomSeed = random.nextLong();
- | int loopCount = 0;
- | while (loopCount < partitionIndex) {
- | randomSeed = random.nextLong();
- | loopCount += 1;
- | }
- | $sampler.setSeed(randomSeed);
- | }
- """.stripMargin.trim)
-
- ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
- s"$initSamplerFuncName();")
+ val sampler = ctx.addMutableState(s"$samplerClass<UnsafeRow>",
"sampleReplace",
+ v => {
+ val initSamplerFuncName = ctx.addNewFunction(initSampler,
+ s"""
+ | private void $initSampler() {
+ | $v = new $samplerClass<UnsafeRow>($upperBound -
$lowerBound, false);
+ | java.util.Random random = new java.util.Random(${seed}L);
+ | long randomSeed = random.nextLong();
+ | int loopCount = 0;
+ | while (loopCount < partitionIndex) {
+ | randomSeed = random.nextLong();
+ | loopCount += 1;
+ | }
+ | $v.setSeed(randomSeed);
+ | }
+ """.stripMargin.trim)
--- End diff --
I think the `.trim` call after `stripMargin` is not needed here.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]