Github user mgaido91 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22017#discussion_r208602347
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
 ---
    @@ -442,3 +442,191 @@ case class ArrayAggregate(
     
       override def prettyName: String = "aggregate"
     }
    +
    +/**
    + * Merges two given maps into a single map by applying function to the 
pair of values with
    + * the same key.
    + */
    +@ExpressionDescription(
    +  usage =
    +    """
    +      _FUNC_(map1, map2, function) - Merges two given maps into a single 
map by applying
    +      function to the pair of values with the same key. For keys only 
presented in one map,
    +      NULL will be passed as the value for the missing key. If an input 
map contains duplicated
    +      keys, only the first entry of the duplicated key is passed into the 
lambda function.
    +    """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, 
v2) -> concat(v1, v2));
    +       {1:"ax",2:"by"}
    +  """,
    +  since = "2.4.0")
    +case class MapZipWith(left: Expression, right: Expression, function: 
Expression)
    +  extends HigherOrderFunction with CodegenFallback {
    +
    +  @transient lazy val functionForEval: Expression = functionsForEval.head
    +
    +  @transient lazy val (keyType, leftValueType, _) =
    +    HigherOrderFunction.mapKeyValueArgumentType(left.dataType)
    +
    +  @transient lazy val (_, rightValueType, _) =
    +    HigherOrderFunction.mapKeyValueArgumentType(right.dataType)
    +
    +  @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)
    +
    +  override def inputs: Seq[Expression] = left :: right :: Nil
    +
    +  override def functions: Seq[Expression] = function :: Nil
    +
    +  override def nullable: Boolean = left.nullable || right.nullable
    +
    +  override def dataType: DataType = MapType(keyType, function.dataType, 
function.nullable)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    (left.dataType, right.dataType) match {
    +      case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) =>
    +        TypeUtils.checkForOrderingExpr(k1, s"function $prettyName")
    +      case _ => TypeCheckResult.TypeCheckFailure(s"The input to function 
$prettyName should have " +
    +        s"been two ${MapType.simpleString}s with the same key type, but 
it's " +
    +        s"[${left.dataType.catalogString}, 
${right.dataType.catalogString}].")
    +    }
    +  }
    +
    +  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => 
LambdaFunction): MapZipWith = {
    +    val arguments = Seq((keyType, false), (leftValueType, true), 
(rightValueType, true))
    +    copy(function = f(function, arguments))
    +  }
    +
    +  override def eval(input: InternalRow): Any = {
    +    val value1 = left.eval(input)
    +    if (value1 == null) {
    +      null
    +    } else {
    +      val value2 = right.eval(input)
    +      if (value2 == null) {
    +        null
    +      } else {
    +        nullSafeEval(input, value1, value2)
    +      }
    +    }
    +  }
    +
    +  @transient lazy val LambdaFunction(_, Seq(
    +    keyVar: NamedLambdaVariable,
    +    value1Var: NamedLambdaVariable,
    +    value2Var: NamedLambdaVariable),
    +    _) = function
    +
    +  private def keyTypeSupportsEquals = keyType match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +
    +  @transient private lazy val getKeysWithValueIndexes:
    +      (ArrayData, ArrayData) => Seq[(Any, Array[Option[Int]])] = {
    +    if (keyTypeSupportsEquals) {
    +      getKeysWithIndexesFast
    +    } else {
    +      getKeysWithIndexesBruteForce
    +    }
    +  }
    +
    +  private def assertSizeOfArrayBuffer(size: Int): Unit = {
    +    if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +      throw new RuntimeException(s"Unsuccessful try to zip maps with $size 
" +
    +        s"unique keys due to exceeding the array size limit " +
    +        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
    +    }
    +  }
    +
    +  private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = 
{
    +    val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
    --- End diff --
    
    You could change the approach a bit. For the first array of keys, you
need to check neither whether a key is already present nor the size of the
output array: you can simply add all of them. Then add the keys from the
second array using the logic here.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to