This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new daedcd29630 [SPARK-39071][SQL][PYTHON] Add unwrap_udt function for unwrapping UserDefinedType columns
daedcd29630 is described below

commit daedcd29630de42f6d0858e1e693f4a3d9caf1aa
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue May 10 19:29:05 2022 +0800

    [SPARK-39071][SQL][PYTHON] Add unwrap_udt function for unwrapping UserDefinedType columns
    
    ### What changes were proposed in this pull request?
    Add unwrap_udt function for unwrapping UserDefinedType columns
    
    ### Why are the changes needed?
    This is useful in open-source project https://github.com/mengxr/pyspark-xgboost
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    new sql function `unwrap_udt` added.
    
    ### How was this patch tested?
    Unit test.
    
    Closes #36408 from WeichenXu123/unwrapt_udt.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/tests/test_linalg.py             | 14 +++++++
 python/pyspark/sql/functions.py                    | 10 +++++
 .../spark/sql/catalyst/expressions/UnwrapUDT.scala | 49 ++++++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     |  9 ++++
 .../apache/spark/sql/UserDefinedTypeSuite.scala    | 10 +++++
 5 files changed, 92 insertions(+)

diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py
index dfdd32e98eb..a6e9f4e752e 100644
--- a/python/pyspark/ml/tests/test_linalg.py
+++ b/python/pyspark/ml/tests/test_linalg.py
@@ -33,6 +33,7 @@ from pyspark.ml.linalg import (
 )
 from pyspark.testing.mllibutils import MLlibTestCase
 from pyspark.sql import Row
+from pyspark.sql.functions import unwrap_udt
 
 
 class VectorTests(MLlibTestCase):
@@ -351,6 +352,19 @@ class VectorUDTTests(MLlibTestCase):
             else:
                 raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
 
+    def test_unwrap_udt(self):
+        df = self.spark.createDataFrame(
+            [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)],
+            ["vec"],
+        )
+        results = df.select(unwrap_udt("vec").alias("v2")).collect()
+        unwrapped_vec = Row("type", "size", "indices", "values")
+        expected = [
+            Row(v2=unwrapped_vec(1, None, None, [1.0, 2.0, 3.0])),
+            Row(v2=unwrapped_vec(0, 3, [1, 2], [1.0, 5.5])),
+        ]
+        self.assertEquals(results, expected)
+
 
 class MatrixUDTTests(MLlibTestCase):
 
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 019f64b5171..aa9aa5ed51b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -5318,6 +5318,16 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column:
     return _invoke_function("bucket", numBuckets, _to_java_column(col))
 
 
+def unwrap_udt(col: "ColumnOrName") -> Column:
+    """
+    Unwrap UDT data type column into its underlying type.
+
+        .. versionadded:: 3.4.0
+
+    """
+    return _invoke_function("unwrap_udt", _to_java_column(col))
+
+
+# ---------------------------- User Defined Function ----------------------------------
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala
new file mode 100644
index 00000000000..cb740672af3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.{DataType, UserDefinedType}
+
+
+/**
+ * Unwrap UDT data type column into its underlying type.
+ */
+case class UnwrapUDT(child: Expression) extends UnaryExpression with NonSQLExpression {
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.genCode(ctx)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (child.dataType.isInstanceOf[UserDefinedType[_]]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"Input type should be UserDefinedType but got 
${child.dataType.catalogString}")
+    }
+  }
+  override def dataType: DataType = child.dataType.asInstanceOf[UserDefinedType[_]].sqlType
+
+  override def nullSafeEval(input: Any): Any = input
+
+  override def prettyName: String = "unwrap_udt"
+
+  override protected def withNewChildInternal(newChild: Expression): UnwrapUDT = {
+    copy(child = newChild)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f6c3bc7e3ce..814a2e472f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -5499,4 +5499,13 @@ object functions {
   def call_udf(udfName: String, cols: Column*): Column = withExpr {
     UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }
+
+  /**
+   * Unwrap UDT data type column into its underlying type.
+   *
+   * @since 3.4.0
+   */
+  def unwrap_udt(column: Column): Column = withExpr {
+    UnwrapUDT(column.expr)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 1f0d971bd72..9bd4a5e6f14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -272,4 +272,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
 
     assert(result.toSet === Set(FooWithDate(year, "FooFoo", 3), FooWithDate(year, "Foo", 1)))
   }
+
+  test("Test unwrap_udt function") {
+    val unwrappedFeatures = pointsRDD.select(unwrap_udt(col("features")))
+      .rdd.map { (row: Row) => row.getAs[Seq[Double]](0).toArray }
+    val unwrappedFeaturesArrays: Array[Array[Double]] = unwrappedFeatures.collect()
+    assert(unwrappedFeaturesArrays.size === 2)
+
+    java.util.Arrays.equals(unwrappedFeaturesArrays(0), Array(0.1, 1.0))
+    java.util.Arrays.equals(unwrappedFeaturesArrays(1), Array(0.2, 2.0))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to