Github user mn-mikke commented on a diff in the pull request:
https://github.com/apache/spark/pull/21386#discussion_r189722777
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -555,6 +557,100 @@ case class ArraySort(child: Expression) extends
UnaryExpression with ArraySortLi
override def prettyName: String = "array_sort"
}
+
+/**
+ * Returns a random permutation of the given array.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array) - Returns a random permutation of the given
array.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 20, 3, 5));
+ [3, 1, 5, 20]
+ > SELECT _FUNC_(array(1, 20, null, 3));
+ [20, null, 3, 1]
+ """, since = "2.4.0")
+case class Shuffle(child: Expression) extends UnaryExpression with
ImplicitCastInputTypes {
+
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ override def dataType: DataType = child.dataType
+
+ lazy val elementType: DataType =
dataType.asInstanceOf[ArrayType].elementType
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => shuffleArrayCodeGen(ctx, ev, c))
+ }
+
+ private def shuffleArrayCodeGen(ctx: CodegenContext, ev: ExprCode,
childName: String): String = {
+ val length = ctx.freshName("length")
+ val javaElementType = CodeGenerator.javaType(elementType)
+ val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
+
+ val initialization = if (isPrimitiveType) {
+ s"${ev.value} = $childName.copy()"
+ } else {
+ s"""
+ |${ev.value} = new ${classOf[GenericArrayData].getName()}(new
Object[$length]);
+ |for (int j = 0; j < $childName.numElements(); j++) {
+ | ${ev.value}.update(j, ${CodeGenerator.getValue(childName,
elementType, "j")});
+ |}
+ """.stripMargin
+ }
+
+ val swapAssigments = if (isPrimitiveType) {
+ val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
+ val getCall = (index: String) => CodeGenerator.getValue(ev.value,
elementType, index)
+ s"""
+ |boolean isNullAtK = ${ev.value}.isNullAt(k);
+ |boolean isNullAtL = ${ev.value}.isNullAt(l);
+ |if (!isNullAtK) {
+ | $javaElementType el = ${getCall("k")};
+ | if(!isNullAtL) {
+ | ${ev.value}.$setFunc(k, ${getCall("l")});
+ | } else {
+ | ${ev.value}.setNullAt(k);
+ | }
+ | ${ev.value}.$setFunc(l, el);
+ |} else if (!isNullAtL) {
+ | ${ev.value}.$setFunc(k, ${getCall("l")});
+ | ${ev.value}.setNullAt(l);
+ |}
+ """.stripMargin
+ } else {
+ s"""
+ |$javaElementType el = ${CodeGenerator.getValue(ev.value,
elementType, "l")};
+ |${ev.value}.update(l, ${CodeGenerator.getValue(ev.value,
elementType, "k")});
+ |${ev.value}.update(k, el);
+ """.stripMargin
+ }
+
+ val randomClass = classOf[Random].getName
+ val rand = ctx.freshName("rand")
+
+ s"""
+ |final int $length = $childName.numElements();
+ |$randomClass $rand = new $randomClass();
+ |$initialization;
+ |for (int k = $length - 1; k >= 1; k--) {
+ | int l = $rand.nextInt(k + 1);
+ | $swapAssigments
+ |}
+ """.stripMargin
+ }
+
+ override protected def nullSafeEval(input: Any): Any = input match {
+ case a: ArrayData =>
+ new GenericArrayData(scala.util.
--- End diff --
Nit: line wrapping. Please reflow this line so the expression is not split right after `scala.util.`.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]