Github user mn-mikke commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22017#discussion_r208941728
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala ---
    @@ -442,3 +442,191 @@ case class ArrayAggregate(
     
       override def prettyName: String = "aggregate"
     }
    +
    +/**
    + * Merges two given maps into a single map by applying function to the pair of values with
    + * the same key.
    + */
    +@ExpressionDescription(
    +  usage =
    +    """
    +      _FUNC_(map1, map2, function) - Merges two given maps into a single 
map by applying
    +      function to the pair of values with the same key. For keys only 
presented in one map,
    +      NULL will be passed as the value for the missing key. If an input 
map contains duplicated
    +      keys, only the first entry of the duplicated key is passed into the 
lambda function.
    +    """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, 
v2) -> concat(v1, v2));
    +       {1:"ax",2:"by"}
    +  """,
    +  since = "2.4.0")
    +case class MapZipWith(left: Expression, right: Expression, function: 
Expression)
    +  extends HigherOrderFunction with CodegenFallback {
    +
    +  @transient lazy val functionForEval: Expression = functionsForEval.head
    +
    +  @transient lazy val (keyType, leftValueType, _) =
    +    HigherOrderFunction.mapKeyValueArgumentType(left.dataType)
    +
    +  @transient lazy val (_, rightValueType, _) =
    +    HigherOrderFunction.mapKeyValueArgumentType(right.dataType)
    +
    +  @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)
    +
    +  override def inputs: Seq[Expression] = left :: right :: Nil
    +
    +  override def functions: Seq[Expression] = function :: Nil
    +
    +  override def nullable: Boolean = left.nullable || right.nullable
    +
    +  override def dataType: DataType = MapType(keyType, function.dataType, 
function.nullable)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    (left.dataType, right.dataType) match {
    +      case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) =>
    +        TypeUtils.checkForOrderingExpr(k1, s"function $prettyName")
    +      case _ => TypeCheckResult.TypeCheckFailure(s"The input to function 
$prettyName should have " +
    +        s"been two ${MapType.simpleString}s with the same key type, but 
it's " +
    +        s"[${left.dataType.catalogString}, 
${right.dataType.catalogString}].")
    +    }
    +  }
    +
    +  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => 
LambdaFunction): MapZipWith = {
    +    val arguments = Seq((keyType, false), (leftValueType, true), 
(rightValueType, true))
    +    copy(function = f(function, arguments))
    +  }
    +
    +  override def eval(input: InternalRow): Any = {
    +    val value1 = left.eval(input)
    +    if (value1 == null) {
    +      null
    +    } else {
    +      val value2 = right.eval(input)
    +      if (value2 == null) {
    +        null
    +      } else {
    +        nullSafeEval(input, value1, value2)
    +      }
    +    }
    +  }
    +
    +  @transient lazy val LambdaFunction(_, Seq(
    +    keyVar: NamedLambdaVariable,
    +    value1Var: NamedLambdaVariable,
    +    value2Var: NamedLambdaVariable),
    +    _) = function
    +
    +  private def keyTypeSupportsEquals = keyType match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +
    +  @transient private lazy val getKeysWithValueIndexes:
    +      (ArrayData, ArrayData) => Seq[(Any, Array[Option[Int]])] = {
    +    if (keyTypeSupportsEquals) {
    +      getKeysWithIndexesFast
    +    } else {
    +      getKeysWithIndexesBruteForce
    +    }
    +  }
    +
    +  private def assertSizeOfArrayBuffer(size: Int): Unit = {
    +    if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +      throw new RuntimeException(s"Unsuccessful try to zip maps with $size 
" +
    +        s"unique keys due to exceeding the array size limit " +
    +        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
    +    }
    +  }
    +
    +  private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = 
{
    +    val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
    --- End diff --
    
    I like this idea, thanks!


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to