Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/16986#discussion_r117840020
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
---
@@ -652,6 +653,299 @@ case class MapObjects private(
}
}
+object CollectObjectsToMap {
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+ /**
+ * Construct an instance of the CollectObjectsToMap case class.
+ *
+ * Each call takes the next id from `curId` and appends it to the names of
+ * the key/value loop variables, the tuple loop variable and the builder
+ * variable. `keyFunction` and `valueFunction` are applied to the
+ * LambdaVariables created for the key and value loop variables, and the
+ * resulting expressions are passed to the case class constructor.
+ *
+ * @param keyFunction The function applied on the key collection
elements.
+ * @param keyInputData An expression that when evaluated returns a key
collection object.
+ * @param keyElementType The data type of key elements in the collection.
+ * @param valueFunction The function applied on the value collection
elements.
+ * @param valueInputData An expression that when evaluated returns a
value collection object.
+ * @param valueElementType The data type of value elements in the
collection.
+ * @param collClass The type of the resulting collection.
+ */
+ def apply(
+ keyFunction: Expression => Expression,
+ keyInputData: Expression,
+ keyElementType: DataType,
+ valueFunction: Expression => Expression,
+ valueInputData: Expression,
+ valueElementType: DataType,
+ collClass: Class[_]): CollectObjectsToMap = {
+ val id = curId.getAndIncrement()
+ val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
+ val keyLoopIsNull = s"CollectObjectsToMap_keyLoopIsNull$id"
+ val keyLoopVar = LambdaVariable(keyLoopValue, keyLoopIsNull,
keyElementType)
+ val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
+ val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
+ val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull,
valueElementType)
+ val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id"
+ val builderValue = s"CollectObjectsToMap_builderValue$id"
+ CollectObjectsToMap(
+ keyLoopValue, keyLoopIsNull, keyElementType,
keyFunction(keyLoopVar), keyInputData,
+ valueLoopValue, valueLoopIsNull, valueElementType,
valueFunction(valueLoopVar),
+ valueInputData,
+ tupleLoopVar, collClass, builderValue)
+ }
+}
+
+/**
+ * An equivalent to the [[MapObjects]] case class but returning an
ObjectType containing
+ * a Scala collection constructed using the associated builder, obtained
by calling `newBuilder`
+ * on the collection's companion object.
+ *
+ * @param keyLoopValue the name of the loop variable that is used when
iterating over the key
+ * collection, and which is used as input for the
`keyLambdaFunction`
+ * @param keyLoopIsNull the nullability of the loop variable that is used
when iterating over
+ * the key collection, and which is used as input for
the `keyLambdaFunction`
+ * @param keyLoopVarDataType the data type of the loop variable that is
used when iterating over
+ * the key collection, and which is used as
input for the
+ * `keyLambdaFunction`
+ * @param keyLambdaFunction A function that takes the `keyLoopVar` as
input, and is used as
+ * a lambda function to handle collection
elements.
+ * @param keyInputData An expression that when evaluated returns a
collection object.
+ * @param valueLoopValue the name of the loop variable that is used when
iterating over the value
+ * collection, and which is used as input for the
`valueLambdaFunction`
+ * @param valueLoopIsNull the nullability of the loop variable that is
used when iterating over
+ * the value collection, and which is used as input
for the
+ * `valueLambdaFunction`
+ * @param valueLoopVarDataType the data type of the loop variable that is
used when iterating over
+ * the value collection, and which is used as
input for the
+ * `valueLambdaFunction`
+ * @param valueLambdaFunction A function that takes the `valueLoopVar` as
input, and is used as
+ * a lambda function to handle collection
elements.
+ * @param valueInputData An expression that when evaluated returns a
collection object.
+ * @param tupleLoopValue the name of the loop variable that holds the
tuple to be added to the
+ * resulting map (used only for Scala Map)
+ * @param collClass The type of the resulting collection.
+ * @param builderValue The name of the builder variable used to construct
the resulting collection.
+ */
+case class CollectObjectsToMap private(
+ keyLoopValue: String,
+ keyLoopIsNull: String,
+ keyLoopVarDataType: DataType,
+ keyLambdaFunction: Expression,
+ keyInputData: Expression,
+ valueLoopValue: String,
+ valueLoopIsNull: String,
+ valueLoopVarDataType: DataType,
+ valueLambdaFunction: Expression,
+ valueInputData: Expression,
+ tupleLoopValue: String,
+ collClass: Class[_],
+ builderValue: String) extends Expression with NonSQLExpression {
+
+ override def nullable: Boolean = keyInputData.nullable
+
+ override def children: Seq[Expression] =
+ keyLambdaFunction :: keyInputData :: valueLambdaFunction ::
valueInputData :: Nil
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated
evaluation is supported")
+
+ override def dataType: DataType = ObjectType(collClass)
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val keyElementJavaType = ctx.javaType(keyLoopVarDataType)
+ ctx.addMutableState("boolean", keyLoopIsNull, "")
+ ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
+ val genKeyInputData = keyInputData.genCode(ctx)
+ val genKeyFunction = keyLambdaFunction.genCode(ctx)
+ val valueElementJavaType = ctx.javaType(valueLoopVarDataType)
+ ctx.addMutableState("boolean", valueLoopIsNull, "")
+ ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
+ val genValueInputData = valueInputData.genCode(ctx)
+ val genValueFunction = valueLambdaFunction.genCode(ctx)
+ val dataLength = ctx.freshName("dataLength")
+ val loopIndex = ctx.freshName("loopIndex")
+
+ // In RowEncoder, we use `Object` to represent Array or Seq, so we
need to determine the type
+ // of input collection at runtime for this case.
+ val keySeq = ctx.freshName("keySeq")
+ val keyArray = ctx.freshName("keyArray")
+ val valueSeq = ctx.freshName("valueSeq")
+ val valueArray = ctx.freshName("valueArray")
+ def determineCollectionType(inputData: Expression, genInputData:
ExprCode,
+ elementJavaType: String, seq: String,
array: String) =
+ inputData.dataType match {
+ case ObjectType(cls) if cls == classOf[Object] =>
+ val seqClass = classOf[Seq[_]].getName
+ s"""
+ $seqClass $seq = null;
+ $elementJavaType[] $array = null;
+ if (${genInputData.value}.getClass().isArray()) {
+ $array = ($elementJavaType[]) ${genInputData.value};
+ } else {
+ $seq = ($seqClass) ${genInputData.value};
+ }
+ """
+ case _ => ""
+ }
+ val determineKeyCollectionType = determineCollectionType(
+ keyInputData, genKeyInputData, keyElementJavaType, keySeq, keyArray)
+ val determineValueCollectionType = determineCollectionType(
+ valueInputData, genValueInputData, valueElementJavaType, valueSeq,
valueArray)
+
+ // The data with PythonUserDefinedType are actually stored with the
data type of its sqlType.
+ // When we want to apply MapObjects on it, we have to use it.
+ def inputDataType(inputData: Expression) = inputData.dataType match {
+ case p: PythonUserDefinedType => p.sqlType
+ case _ => inputData.dataType
+ }
+ val keyInputDataType = inputDataType(keyInputData)
+ val valueInputDataType = inputDataType(valueInputData)
+
+ def lengthAndLoopVar(inputDataType: DataType, genInputData: ExprCode,
+ seq: String, array: String) =
+ inputDataType match {
+ case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ s"${genInputData.value}.size()" ->
s"${genInputData.value}.apply($loopIndex)"
+ case ObjectType(cls) if cls.isArray =>
+ s"${genInputData.value}.length" ->
s"${genInputData.value}[$loopIndex]"
+ case ObjectType(cls) if
classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ s"${genInputData.value}.size()" ->
s"${genInputData.value}.get($loopIndex)"
+ case ArrayType(et, _) =>
+ s"${genInputData.value}.numElements()" ->
ctx.getValue(genInputData.value, et, loopIndex)
+ case ObjectType(cls) if cls == classOf[Object] =>
+ s"$seq == null ? $array.length : $seq.size()" ->
+ s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
+ }
+ val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar))
= (
+ lengthAndLoopVar(inputDataType(keyInputData), genKeyInputData,
keySeq, keyArray),
+ lengthAndLoopVar(inputDataType(valueInputData), genValueInputData,
valueSeq, valueArray)
+ )
+
+ // Make a copy of the data if it's unsafe-backed
+ def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
+ s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
+ def genFunctionValue(lambdaFunction: Expression, genFunction:
ExprCode) =
+ lambdaFunction.dataType match {
+ case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow],
genFunction.value)
+ case ArrayType(_, _) =>
makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
+ case MapType(_, _, _) =>
makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
+ case _ => genFunction.value
+ }
+ val genKeyFunctionValue = genFunctionValue(keyLambdaFunction,
genKeyFunction)
+ val genValueFunctionValue = genFunctionValue(valueLambdaFunction,
genValueFunction)
+
+ def loopNullCheck(genInputData: ExprCode, inputDataType: DataType,
+ loopIsNull: String, loopValue: String) =
+ inputDataType match {
+ case _: ArrayType => s"$loopIsNull =
${genInputData.value}.isNullAt($loopIndex);"
+ // The element of primitive array will never be null.
+ case ObjectType(cls) if cls.isArray &&
cls.getComponentType.isPrimitive =>
+ s"$loopIsNull = false"
+ case _ => s"$loopIsNull = $loopValue == null;"
+ }
+ val keyLoopNullCheck =
+ loopNullCheck(genKeyInputData, keyInputDataType, keyLoopIsNull,
keyLoopValue)
+ val valueLoopNullCheck =
+ loopNullCheck(genValueInputData, valueInputDataType,
valueLoopIsNull, valueLoopValue)
+
+ val constructBuilder = collClass match {
+ // Scala Map
+ case cls if classOf[scala.collection.Map[_,
_]].isAssignableFrom(cls) =>
+ val builderClass = classOf[Builder[_, _]].getName
+ s"""
+ $builderClass $builderValue =
${collClass.getName}$$.MODULE$$.newBuilder();
+ $builderValue.sizeHint($dataLength);
+ """
+ // Java Map, AbstractMap => HashMap
+ case cls if classOf[java.util.Map[_, _]] == cls ||
+ classOf[java.util.AbstractMap[_, _]] == cls =>
+ val builderClass = classOf[java.util.HashMap[_, _]].getName
+ s"$builderClass $builderValue = new $builderClass($dataLength);"
+ // Java SortedMap, NavigableMap => TreeMap
+ case cls if classOf[java.util.SortedMap[_, _]] == cls ||
+ classOf[java.util.NavigableMap[_, _]] == cls =>
+ val builderClass = classOf[java.util.TreeMap[_, _]].getName
+ s"$builderClass $builderValue = new $builderClass();"
+ // Java ConcurrentMap => ConcurrentHashMap
+ case cls if classOf[java.util.concurrent.ConcurrentMap[_, _]] == cls
=>
+ val builderClass =
classOf[java.util.concurrent.ConcurrentHashMap[_, _]].getName
+ s"$builderClass $builderValue = new $builderClass();"
+ // Java ConcurrentNavigableMap => ConcurrentSkipListMap
+ case cls if classOf[java.util.concurrent.ConcurrentNavigableMap[_,
_]] == cls =>
+ val builderClass =
classOf[java.util.concurrent.ConcurrentSkipListMap[_, _]].getName
+ s"$builderClass $builderValue = new $builderClass();"
+ // Java concrete Map implementation
+ case cls =>
+ val builderClass = classOf[java.util.Map[_, _]].getName
+ // Check for constructor with capacity specification
+ if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) {
+ s"$builderClass $builderValue = new ${cls.getName}($dataLength);"
+ } else {
+ s"$builderClass $builderValue = new ${cls.getName}();"
+ }
+ }
+
+ val (appendToBuilder, getBuilderResult) =
+ if (classOf[scala.collection.Map[_, _]].isAssignableFrom(collClass))
{
+ val tupleClass = classOf[(_, _)].getName
+ s"""
+ $tupleClass $tupleLoopValue;
+
+ if (${genValueFunction.isNull}) {
+ $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
+ } else {
+ $tupleLoopValue = new $tupleClass($genKeyFunctionValue,
$genValueFunctionValue);
+ }
+
+ $builderValue.$$plus$$eq($tupleLoopValue);
--- End diff --
This is OK, but it would be great if there were a way to avoid creating the
tuple every time.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]