GitHub user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21966#discussion_r207302241
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -4077,81 +4078,84 @@ case class ArrayExcept(left: Expression, right:
Expression) extends ArraySetLike
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayData = classOf[ArrayData].getName
val i = ctx.freshName("i")
- val pos = ctx.freshName("pos")
val value = ctx.freshName("value")
- val hsValue = ctx.freshName("hsValue")
val size = ctx.freshName("size")
- if (elementTypeSupportEquals) {
- val ptName = CodeGenerator.primitiveTypeName(elementType)
- val unsafeArray = ctx.freshName("unsafeArray")
- val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
- getter, setter, javaTypeName, primitiveTypeName,
arrayDataBuilder) =
- elementType match {
- case ByteType | ShortType | IntegerType =>
- ("$mcI$sp", "Int", "int", s"(int) $value",
- s"get$ptName($i)", s"set$ptName($pos, $value)",
- CodeGenerator.javaType(elementType), ptName,
- s"""
- |${ctx.createUnsafeArray(unsafeArray, size, elementType,
s" $prettyName failed.")}
- |${ev.value} = $unsafeArray;
- """.stripMargin)
- case LongType | FloatType | DoubleType =>
- val signature = elementType match {
- case LongType => "$mcJ$sp"
- case FloatType => "$mcF$sp"
- case DoubleType => "$mcD$sp"
- }
- (signature, CodeGenerator.boxedType(elementType),
- CodeGenerator.javaType(elementType), value,
- s"get$ptName($i)", s"set$ptName($pos, $value)",
- CodeGenerator.javaType(elementType), ptName,
- s"""
- |${ctx.createUnsafeArray(unsafeArray, size, elementType,
s" $prettyName failed.")}
- |${ev.value} = $unsafeArray;
- """.stripMargin)
- case _ =>
- val genericArrayData = classOf[GenericArrayData].getName
- val et = ctx.addReferenceObj("elementType", elementType)
- ("", "Object", "Object", value,
- s"get($i, $et)", s"update($pos, $value)", "Object", "Ref",
- s"${ev.value} = new $genericArrayData(new Object[$size]);")
- }
+ val canUseSpecializedHashSet = elementType match {
+ case ByteType | ShortType | IntegerType | LongType | FloatType |
DoubleType => true
+ case _ => false
+ }
+ if (canUseSpecializedHashSet) {
+ val jt = CodeGenerator.javaType(elementType)
+ val ptName = CodeGenerator.primitiveTypeName(jt)
+
+ def genGetValue(array: String): String =
+ CodeGenerator.getValue(array, elementType, i)
+
+ val (hsPostFix, hsTypeName) = elementType match {
+ // we cast byte/short to int when writing to the hash set.
+ case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
+ case LongType => ("$mcJ$sp", ptName)
+ case FloatType => ("$mcF$sp", ptName)
+ case DoubleType => ("$mcD$sp", ptName)
+ }
+
+ // we cast byte/short to int when writing to the hash set.
+ val hsValueCast = elementType match {
+ case ByteType | ShortType => "(int) "
+ case _ => ""
+ }
nullSafeCodeGen(ctx, ev, (array1, array2) => {
val notFoundNullElement = ctx.freshName("notFoundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
- val array = ctx.freshName("array")
val openHashSet = classOf[OpenHashSet[_]].getName
- val classTag =
s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
- val hs = ctx.freshName("hs")
+ val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+ val hashSet = ctx.freshName("hashSet")
val genericArrayData = classOf[GenericArrayData].getName
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
- val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName"
- val arrayBuilderClassTag = if (primitiveTypeName != "Ref") {
- s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()"
- } else {
- s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()"
- }
+ val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
+ val arrayBuilderClassTag =
s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
- def withArray2NullCheck(body: String) =
- if (right.dataType.asInstanceOf[ArrayType].containsNull) {
- s"""
- |if ($array2.isNullAt($i)) {
- | $notFoundNullElement = false;
- |} else {
- | $body
- |}
+ def withArray2NullCheck(body: String): String =
+ if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+ if (right.dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if ($array2.isNullAt($i)) {
+ | $notFoundNullElement = false;
+ |} else {
+ | $body
+ |}
""".stripMargin
+ } else {
+ body
+ }
} else {
- body
+ // if array1's element is not nullable, we don't need to track
the null element index.
+ if (right.dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if (!$array2.isNullAt($i)) {
+ | $body
+ |}
+ """.stripMargin
+ } else {
+ body
+ }
}
- val array2Body =
+
+ val writeArray2ToHashSet = withArray2NullCheck(
s"""
- |$javaTypeName $value = $array2.$getter;
- |$hsJavaTypeName $hsValue = $genHsValue;
- |$hs.add$postFix($hsValue);
- """.stripMargin
+ |$jt $value = ${genGetValue(array2)};
+ |$hashSet.add$hsPostFix($hsValueCast$value);
+ """.stripMargin)
+
+ // When hitting a null vale, put a null holder in the
ArrayBuilder. Finally we will
--- End diff --
nit: typo in the comment above: `vale` -> `value`
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]