This is an automated email from the ASF dual-hosted git repository.

eyal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/datafu.git
The following commit(s) were added to refs/heads/master by this push:
     new 3bc8f60  DATAFU-165 Added dedupRandomN method and collectLimitedList UDAF functionality
3bc8f60 is described below

commit 3bc8f60fb3c57ef95501b466431238658505518b
Author: Rahamim, Ben <braha...@paypal.com>
AuthorDate: Thu Aug 4 12:44:33 2022 +0300

    DATAFU-165 Added dedupRandomN method and collectLimitedList UDAF functionality

    Signed-off-by: Eyal Allweil <e...@apache.org>
---
 .../src/main/scala/datafu/spark/DataFrameOps.scala |  5 +-
 .../src/main/scala/datafu/spark/SparkDFUtils.scala | 18 +++++++
 .../utils/overwrites/SparkOverwriteUDAFs.scala     | 61 +++++++++++++++++++++-
 .../test/scala/datafu/spark/TestSparkUDAFs.scala   | 22 ++++++++
 4 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
index 04ae90a..5b4f42d 100644
--- a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -103,6 +103,9 @@ object DataFrameOps {
 
     def explodeArray(arrayCol: Column, alias: String) =
-      SparkDFUtils.explodeArray(df, arrayCol, alias)
+      SparkDFUtils.explodeArray(df, arrayCol, alias)
+
+    def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame =
+      SparkDFUtils.dedupRandomN(df, groupCol, maxSize)
   }
 }
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
index 79b51eb..4fd068f 100644
--- a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -129,6 +129,10 @@ class SparkDFUtilsBridge {
     )
   }
 
+  def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame = {
+    SparkDFUtils.dedupRandomN(df, groupCol, maxSize)
+  }
+
   private def convertJavaListToSeq[T](list: JavaList[T]): Seq[T] = {
     scala.collection.JavaConverters
       .asScalaIteratorConverter(list.iterator())
@@ -550,4 +554,18 @@ object SparkDFUtils {
     val exprs = (0 until arrSize).map(i => arrayCol.getItem(i).alias(s"$alias$i"))
     df.select((col("*") +: exprs):_*)
   }
+
+  /**
+   * Used to get the random n records in each group. Uses an efficient implementation
+   * that doesn't order the data, so it can handle large amounts of data.
+   *
+   * @param df       DataFrame to operate on
+   * @param groupCol column to group the records by
+   * @param maxSize  the maximal number of rows per group
+   * @return DataFrame representing the data after the operation
+   */
+  def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame = {
+    df.groupBy(groupCol).agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
+      .select(groupCol,expr("explode(list)"))
+  }
 }
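A minimal usage sketch for the new SparkDFUtils.dedupRandomN follows; the local SparkSession, the sample DataFrame, and its column names are illustrative only and not part of the patch.

    import datafu.spark.SparkDFUtils
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    // Illustrative local session and toy data, just to show the call shape.
    val spark = SparkSession.builder().master("local[*]").appName("dedupRandomN-sketch").getOrCreate()
    import spark.implicits._

    val events = Seq(
      ("a", 1), ("a", 2), ("a", 3),
      ("b", 4), ("b", 5)
    ).toDF("group_key", "value")

    // Keep at most 2 arbitrary rows per group_key; the data is never sorted.
    val sampled = SparkDFUtils.dedupRandomN(events, col("group_key"), 2)
    sampled.show(false)
    // Output schema: group_key plus an exploded struct column holding the original row fields.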
diff --git a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
index 04d68d6..2dcb5be 100644
--- a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
+++ b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
@@ -19,17 +19,23 @@ package org.apache.spark.sql.datafu.types
 
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, DeclarativeAggregate, ImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
 
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
 object SparkOverwriteUDAFs {
   def minValueByKey(key: Column, value: Column): Column =
     Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))
   def maxValueByKey(key: Column, value: Column): Column =
     Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))
+  def collectLimitedList(e: Column, maxSize: Int): Column =
+    Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))
 }
 
 case class MinValueByKey(child1: Expression, child2: Expression)
@@ -88,3 +94,54 @@ abstract class ExtramumValueByKey(
 
   override lazy val evaluateExpression: AttributeReference = data
 }
+
+/** *
+ *
+ * This code is copied from CollectList, just modified the method it extends
+ * Copied originally from https://github.com/apache/spark/blob/branch-2.3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+ *
+ */
+case class CollectLimitedList(child: Expression,
+                              mutableAggBufferOffset: Int = 0,
+                              inputAggBufferOffset: Int = 0,
+                              howMuchToTake: Int = 10) extends LimitedCollect[mutable.ArrayBuffer[Any]](howMuchToTake) {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+
+  override def prettyName: String = "collect_limited_list"
+
+}
+
+/** *
+ *
+ * This modifies the collect list / set to keep only howMuchToTake random elements
+ *
+ */
+abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTake: Int) extends Collect[T] with Serializable {
+
+  override def update(buffer: T, input: InternalRow): T = {
+    if (buffer.size < howMuchToTake)
+      super.update(buffer, input)
+    else
+      buffer
+  }
+
+  override def merge(buffer: T, other: T): T = {
+    if (buffer.size == howMuchToTake)
+      buffer
+    else if (other.size == howMuchToTake)
+      other
+    else {
+      val howMuchToTakeFromOtherBuffer = howMuchToTake - buffer.size
+      buffer ++= other.take(howMuchToTakeFromOtherBuffer)
+    }
+  }
+}
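The size cap in LimitedCollect can be restated with plain Scala collections; the sketch below only mirrors the update/merge logic shown above and does not touch the actual aggregation buffers or Catalyst types.

    import scala.collection.mutable.ArrayBuffer

    // Add an element only while the buffer is below the cap (mirrors LimitedCollect.update).
    def updateCapped(buffer: ArrayBuffer[Any], element: Any, howMuchToTake: Int): ArrayBuffer[Any] =
      if (buffer.size < howMuchToTake) buffer += element else buffer

    // Combine two partial buffers without exceeding the cap (mirrors LimitedCollect.merge).
    def mergeCapped(buffer: ArrayBuffer[Any], other: ArrayBuffer[Any], howMuchToTake: Int): ArrayBuffer[Any] =
      if (buffer.size == howMuchToTake) buffer
      else if (other.size == howMuchToTake) other
      else buffer ++= other.take(howMuchToTake - buffer.size)

    val left = ArrayBuffer[Any](1, 2, 3)
    val right = ArrayBuffer[Any](4, 5, 6, 7)
    println(mergeCapped(left, right, 5)) // ArrayBuffer(1, 2, 3, 4, 5) -- never more than 5 elements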
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
index e80e71f..aadd059 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -258,4 +258,26 @@ class UdafTests extends FunSuite with DataFrameSuiteBase {
         .agg(countDistinctUpTo6($"col_ord").as("col_ord")))
   }
 
+  test("test_limited_collect_list") {
+
+    val maxSize = 10
+
+    val rows = (1 to 30).flatMap(x => (1 to x).map(n => (x, n, "some-string " + n))).toDF("num1", "num2", "str")
+
+    rows.show(10, false)
+
+    import org.apache.spark.sql.functions._
+
+    val result = rows.groupBy("num1").agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
+      .withColumn("list_size", expr("size(list)"))
+
+    result.show(10, false)
+
+    SparkDFUtils.dedupRandomN(rows,$"num1",10).show(10,false)
+
+    val rows_different = result.filter(s"case when num1 > $maxSize then $maxSize else num1 end != list_size")
+
+    Assert.assertEquals(0, rows_different.count())
+
+  }
 }
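The assertion at the end of test_limited_collect_list reduces to a simple rule: group num1 contains num1 rows, so the collected list should hold min(num1, maxSize) elements. A standalone restatement of that expectation (not part of the test itself):

    // Expected list size per group: the whole group when it is small, otherwise the cap.
    def expectedListSize(groupSize: Int, maxSize: Int): Int = math.min(groupSize, maxSize)

    assert(expectedListSize(3, 10) == 3)   // a group with 3 rows is kept whole
    assert(expectedListSize(25, 10) == 10) // a group with 25 rows is capped at maxSize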