Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18416#discussion_r124332023
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 ---
    @@ -834,6 +834,140 @@ case class CollectObjectsToMap private(
       }
     }
     
    +object CollectObjectsToSet {
    +  private val curId = new java.util.concurrent.atomic.AtomicInteger()
    +
    +  /**
    +   * Construct an instance of CollectObjectsToSet case class.
    +   *
    +   * @param function The function applied on the collection elements.
    +   * @param inputData An expression that when evaluated returns a 
collection object.
    +   * @param collClass The type of the resulting collection.
    +   */
    +  def apply(
    +      function: Expression => Expression,
    +      inputData: Expression,
    +      collClass: Class[_]): CollectObjectsToSet = {
    +    val id = curId.getAndIncrement()
    +    val loopValue = s"CollectObjectsToSet_loopValue$id"
    +    val loopIsNull = s"CollectObjectsToSet_loopIsNull$id"
    +    val arrayType = inputData.dataType.asInstanceOf[ArrayType]
    +    val loopVar = LambdaVariable(loopValue, loopIsNull, 
arrayType.elementType)
    +    CollectObjectsToSet(
    +      loopValue, loopIsNull, function(loopVar), inputData, collClass)
    +  }
    +}
    +
    +/**
    + * Expression used to convert a Catalyst Array to an external Scala `Set`.
    + * The collection is constructed using the associated builder, obtained by 
calling `newBuilder`
    + * on the collection's companion object.
    + *
    + * Notice that when we convert a Catalyst array which contains duplicated 
elements to an external
    + * Scala `Set`, the elements will be de-duplicated.
    + *
    + * @param loopValue the name of the loop variable that is used when 
iterating over the value
    + *                       collection, and which is used as input for the 
`lambdaFunction`
    + * @param loopIsNull the nullability of the loop variable that is used 
when iterating over
    + *                        the value collection, and which is used as input 
for the
    + *                        `lambdaFunction`
    + * @param lambdaFunction A function that takes the `loopValue` as input, 
and is used as
    + *                            a lambda function to handle collection 
elements.
    + * @param inputData An expression that when evaluated returns an array 
object.
    + * @param collClass The type of the resulting collection.
    + */
    +case class CollectObjectsToSet private(
    +    loopValue: String,
    +    loopIsNull: String,
    +    lambdaFunction: Expression,
    +    inputData: Expression,
    +    collClass: Class[_]) extends Expression with NonSQLExpression {
    +
    +  override def nullable: Boolean = inputData.nullable
    +
    +  override def children: Seq[Expression] = lambdaFunction :: inputData :: 
Nil
    +
    +  override def eval(input: InternalRow): Any =
    +    throw new UnsupportedOperationException("Only code-generated 
evaluation is supported")
    +
    +  override def dataType: DataType = ObjectType(collClass)
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    // The data with PythonUserDefinedType are actually stored with the 
data type of its sqlType.
    +    def inputDataType(dataType: DataType) = dataType match {
    +      case p: PythonUserDefinedType => p.sqlType
    +      case _ => dataType
    +    }
    +
    +    val arrayType = 
inputDataType(inputData.dataType).asInstanceOf[ArrayType]
    +    val loopValueJavaType = ctx.javaType(arrayType.elementType)
    +    ctx.addMutableState("boolean", loopIsNull, "")
    +    ctx.addMutableState(loopValueJavaType, loopValue, "")
    +    val genFunction = lambdaFunction.genCode(ctx)
    +
    +    val genInputData = inputData.genCode(ctx)
    +    val dataLength = ctx.freshName("dataLength")
    +    val loopIndex = ctx.freshName("loopIndex")
    +    val builderValue = ctx.freshName("builderValue")
    +
    +    val getLength = s"${genInputData.value}.numElements()"
    +    val getLoopVar = ctx.getValue(genInputData.value, 
arrayType.elementType, loopIndex)
    +
    +    // Make a copy of the data if it's unsafe-backed
    +    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
    +      s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
    +    val genFunctionValue =
    +      lambdaFunction.dataType match {
    +        case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], 
genFunction.value)
    +        case ArrayType(_, _) => 
makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
    +        case MapType(_, _, _) => 
makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
    +        case _ => genFunction.value
    +      }
    +
    +    val loopNullCheck = s"$loopIsNull = 
${genInputData.value}.isNullAt($loopIndex);"
    +
    +    val builderClass = classOf[Builder[_, _]].getName
    +    val constructBuilder = s"""
    +      $builderClass $builderValue = 
${collClass.getName}$$.MODULE$$.newBuilder();
    +      $builderValue.sizeHint($dataLength);
    +    """
    +
    +    val appendToBuilder = s"""
    +      if (${genFunction.isNull}) {
    +        $builderValue.$$plus$$eq(null);
    +      } else {
    +        $builderValue.$$plus$$eq($genFunctionValue);
    +      }
    +     """
    +    val getBuilderResult = s"${ev.value} = (${collClass.getName}) 
$builderValue.result();"
    +
    +    val code = s"""
    +      ${genInputData.code}
    +      ${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
    +
    +      if (!${genInputData.isNull}) {
    +        int $dataLength = $getLength;
    +        $constructBuilder
    +
    +        int $loopIndex = 0;
    +        while ($loopIndex < $dataLength) {
    +          $loopValue = ($loopValueJavaType) ($getLoopVar);
    +
    +          $loopNullCheck
    --- End diff --
    
    I checked the generated code:
    ```
    scala> val ds = Seq(Seq(1), Seq(1, 2)).toDF("a").as[Set[Int]]
    ds: org.apache.spark.sql.Dataset[Set[Int]] = [a: array<int>]
    
    scala> ds.printSchema
    root
     |-- a: array (nullable = true)
     |    |-- element: integer (containsNull = false)
    
    scala> ds.filter(_.size > 3).debugCodegen
    ...
    /* 056 */             int filter_loopIndex = 0;
    /* 057 */             while (filter_loopIndex < filter_dataLength) {
    /* 058 */               CollectObjectsToSet_loopValue6 = (int) 
(inputadapter_value.getInt(filter_loopIndex));
    /* 059 */
    /* 060 */               CollectObjectsToSet_loopIsNull6 = 
inputadapter_value.isNullAt(filter_loopIndex);
    /* 061 */
    /* 062 */               if (CollectObjectsToSet_loopIsNull6) {
    /* 063 */                 filter_builderValue.$plus$eq(null);
    /* 064 */               } else {
    /* 065 */                 
filter_builderValue.$plus$eq(CollectObjectsToSet_loopValue6);
    /* 066 */               }
    /* 067 */
    /* 068 */               filter_loopIndex += 1;
    /* 069 */             }
    ...
    ```
    The schema reports `containsNull = false`, though, so do we still need the null check in this loop?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to