Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21802#discussion_r205534846 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -1184,6 +1184,110 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi override def prettyName: String = "array_sort" } +/** + * Returns a random permutation of the given array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a random permutation of the given array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, 3, 5)); + [3, 1, 5, 20] + > SELECT _FUNC_(array(1, 20, null, 3)); + [20, null, 3, 1] + """, since = "2.4.0") +case class Shuffle(child: Expression, randomSeed: Option[Long] = None) + extends UnaryExpression with ExpectsInputTypes with Stateful { + + def this(child: Expression) = this(child, None) + + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private[this] var random: RandomIndicesGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = { + random = RandomIndicesGenerator(randomSeed.get + partitionIndex) + } + + override protected def evalInternal(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val source = value.asInstanceOf[ArrayData] + val numElements = source.numElements() + val indices = random.getNextIndices(numElements) + new GenericArrayData(indices.map(source.get(_, elementType))) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => shuffleArrayCodeGen(ctx, ev, c)) + } + + private def 
shuffleArrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val randomClass = classOf[RandomIndicesGenerator].getName + + val rand = ctx.addMutableState(randomClass, "rand", forceInline = true) + ctx.addPartitionInitializationStatement( + s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);") + + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") --- End diff -- Actually, we need a new variable here: `ctx.createUnsafeArray()` currently declares the variable it is given, whereas `ev.value` is already declared, so `ev.value` cannot be passed to it directly.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org