GitHub user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/22013#discussion_r210160909
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
---
@@ -497,6 +497,65 @@ case class ArrayAggregate(
override def prettyName: String = "aggregate"
}
/**
 * Transforms the key of every entry of a map by applying `function` to each (key, value)
 * pair. The values (and their nullability) are carried over unchanged; only the key type
 * is taken from the result type of `function`.
 *
 * Evaluation throws a RuntimeException if `function` returns null for any entry, because
 * null cannot be used as a map key.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms keys in a map using the function.",
  examples = """
    Examples:
      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
       {2:1,3:2,4:3}
      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       {2:1,4:2,6:3}
  """,
  since = "2.4.0")
case class TransformKeys(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  override def nullable: Boolean = argument.nullable

  // Must be lazy: when this expression is constructed, `argument` may still be unresolved
  // (so asking for its dataType can fail) or may not be a map at all (the pattern would
  // throw a MatchError before input type checking gets a chance to report a proper error).
  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

  // Only the key type changes; the value type and its nullability come from the input map.
  override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull)

  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
    // The key variable is bound as non-nullable; the value variable keeps the map's
    // value nullability.
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  // The bound lambda is expected to take exactly two variables: (key, value).
  @transient lazy val LambdaFunction(
    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val map = argumentValue.asInstanceOf[MapData]
    // Hoisted out of the loop: these do not change between iterations.
    val numElements = map.numElements
    val keys = map.keyArray()
    val values = map.valueArray()
    val f = functionForEval
    val resultKeys = new GenericArrayData(new Array[Any](numElements))
    var i = 0
    while (i < numElements) {
      keyVar.value.set(keys.get(i, keyVar.dataType))
      valueVar.value.set(values.get(i, valueVar.dataType))
      val result = f.eval(inputRow)
      if (result == null) {
        throw new RuntimeException("Cannot use null as map key!")
      }
      resultKeys.update(i, result)
      i += 1
    }
    // NOTE(review): distinct input keys may be mapped to the same output key; duplicates are
    // not detected here, so the resulting map can contain repeated keys.
    new ArrayBasedMapData(resultKeys, values)
  }

  override def prettyName: String = "transform_keys"
}
--- End diff --
Nit: indentation — this closing brace should not be indented.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]