This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new daedcd29630 [SPARK-39071][SQL][PYTHON] Add unwrap_udt function for unwrapping UserDefinedType columns daedcd29630 is described below commit daedcd29630de42f6d0858e1e693f4a3d9caf1aa Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Tue May 10 19:29:05 2022 +0800 [SPARK-39071][SQL][PYTHON] Add unwrap_udt function for unwrapping UserDefinedType columns ### What changes were proposed in this pull request? Add unwrap_udt function for unwrapping UserDefinedType columns ### Why are the changes needed? This is useful in open-source project https://github.com/mengxr/pyspark-xgboost ### Does this PR introduce _any_ user-facing change? Yes. new sql function `unwrap_udt` added. ### How was this patch tested? Unit test. Closes #36408 from WeichenXu123/unwrapt_udt. Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> --- python/pyspark/ml/tests/test_linalg.py | 14 +++++++ python/pyspark/sql/functions.py | 10 +++++ .../spark/sql/catalyst/expressions/UnwrapUDT.scala | 49 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 9 ++++ .../apache/spark/sql/UserDefinedTypeSuite.scala | 10 +++++ 5 files changed, 92 insertions(+) diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index dfdd32e98eb..a6e9f4e752e 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -33,6 +33,7 @@ from pyspark.ml.linalg import ( ) from pyspark.testing.mllibutils import MLlibTestCase from pyspark.sql import Row +from pyspark.sql.functions import unwrap_udt class VectorTests(MLlibTestCase): @@ -351,6 +352,19 @@ class VectorUDTTests(MLlibTestCase): else: raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + def test_unwrap_udt(self): + df = self.spark.createDataFrame( + [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)], + ["vec"], + ) + results = df.select(unwrap_udt("vec").alias("v2")).collect() + unwrapped_vec = Row("type", "size", "indices", "values") + expected = [ + Row(v2=unwrapped_vec(1, None, None, [1.0, 2.0, 3.0])), + Row(v2=unwrapped_vec(0, 3, [1, 2], [1.0, 5.5])), + ] + self.assertEquals(results, expected) + class MatrixUDTTests(MLlibTestCase): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 019f64b5171..aa9aa5ed51b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -5318,6 +5318,16 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column: return _invoke_function("bucket", numBuckets, _to_java_column(col)) +def unwrap_udt(col: "ColumnOrName") -> Column: + """ + Unwrap UDT data type column into its underlying type. + + .. versionadded:: 3.4.0 + + """ + return _invoke_function("unwrap_udt", _to_java_column(col)) + + # ---------------------------- User Defined Function ---------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala new file mode 100644 index 00000000000..cb740672af3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, UserDefinedType} + + +/** + * Unwrap UDT data type column into its underlying type. + */ +case class UnwrapUDT(child: Expression) extends UnaryExpression with NonSQLExpression { + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.genCode(ctx) + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[UserDefinedType[_]]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"Input type should be UserDefinedType but got ${child.dataType.catalogString}") + } + } + override def dataType: DataType = child.dataType.asInstanceOf[UserDefinedType[_]].sqlType + + override def nullSafeEval(input: Any): Any = input + + override def prettyName: String = "unwrap_udt" + + override protected def withNewChildInternal(newChild: Expression): UnwrapUDT = { + copy(child = newChild) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f6c3bc7e3ce..814a2e472f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -5499,4 +5499,13 @@ object functions { def call_udf(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } + + /** + * Unwrap UDT data type column into its underlying type. + * + * @since 3.4.0 + */ + def unwrap_udt(column: Column): Column = withExpr { + UnwrapUDT(column.expr) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 1f0d971bd72..9bd4a5e6f14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -272,4 +272,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque assert(result.toSet === Set(FooWithDate(year, "FooFoo", 3), FooWithDate(year, "Foo", 1))) } + + test("Test unwrap_udt function") { + val unwrappedFeatures = pointsRDD.select(unwrap_udt(col("features"))) + .rdd.map { (row: Row) => row.getAs[Seq[Double]](0).toArray } + val unwrappedFeaturesArrays: Array[Array[Double]] = unwrappedFeatures.collect() + assert(unwrappedFeaturesArrays.size === 2) + + java.util.Arrays.equals(unwrappedFeaturesArrays(0), Array(0.1, 1.0)) + java.util.Arrays.equals(unwrappedFeaturesArrays(1), Array(0.2, 2.0)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org