Repository: spark
Updated Branches:
  refs/heads/master 8f33731e7 -> 4a5e38f57


[SPARK-19161][PYTHON][SQL] Improving UDF Docstrings

## What changes were proposed in this pull request?

Replaces the `UserDefinedFunction` object returned from `udf` with a function 
wrapper that provides docstring and argument information, as proposed in 
[SPARK-19161](https://issues.apache.org/jira/browse/SPARK-19161).
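
For illustration only (a hedged sketch, not part of the patch; `add_one` is a 
hypothetical user function), the value returned by `udf` now carries the 
wrapped function's metadata:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def add_one(x):
    """Return x incremented by one."""
    return x + 1

add_one_udf = udf(add_one, IntegerType())

# functools.wraps copies the original metadata onto the wrapper,
# so introspection and help() now show something meaningful.
print(add_one_udf.__name__)  # add_one
print(add_one_udf.__doc__)   # Return x incremented by one.
```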

### Backward incompatible changes:

- `pyspark.sql.functions.udf` will return a `function` instead of a 
`UserDefinedFunction`. To keep the public API backward compatible, we use 
function attributes to mimic the `UserDefinedFunction` API (the `func` and 
`returnType` attributes), as sketched below. This should have minimal impact 
on user code.

  An alternative implementation could use dynamic subclassing. This would 
ensure full backward compatibility but is more fragile in practice.
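
A sketch of the backward-compatible surface (same hypothetical `add_one` as 
above; the assertions mirror what the new `test_udf_wrapper` unit test below 
checks):

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def add_one(x):
    """Return x incremented by one."""
    return x + 1

add_one_udf = udf(add_one, IntegerType())

# The wrapper exposes the old UserDefinedFunction attributes,
# so existing code that accessed them keeps working.
assert add_one_udf.func == add_one
assert add_one_udf.returnType == IntegerType()
```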

### Limitations:

Full functionality (a retained docstring and argument list) is achieved only 
on recent Python versions (Python 3). Legacy Python (Python 2) preserves only 
the docstring, not the argument list. This should be an acceptable trade-off 
between the achieved improvements and the overall complexity.
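
A minimal illustration of the version difference, assuming standard 
`functools.wraps` semantics (pure Python, no Spark needed): `__doc__` is 
copied on both Python 2 and 3, while recovering the argument list relies on 
the `__wrapped__` attribute that `functools.wraps` sets only on Python 3:

```python
import functools
import inspect

def f(x, y=1):
    """Toy function used only for this illustration."""
    return x

@functools.wraps(f)
def wrapper(*args):
    return f(*args)

print(wrapper.__doc__)             # preserved on Python 2 and 3

# Python 3 only: inspect.signature follows __wrapped__ back to f,
# recovering the original argument list.
print(inspect.signature(wrapper))  # (x, y=1)
```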

### Possible impact on other tickets:

This can affect 
[SPARK-18777](https://issues.apache.org/jira/browse/SPARK-18777).

## How was this patch tested?

Existing unit tests ensure backward compatibility; additional tests target 
the proposed changes.

Author: zero323 <zero...@users.noreply.github.com>

Closes #16534 from zero323/SPARK-19161.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4a5e38f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4a5e38f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4a5e38f5

Branch: refs/heads/master
Commit: 4a5e38f5747148022988631cae0248ae1affadd3
Parents: 8f33731
Author: zero323 <zero...@users.noreply.github.com>
Authored: Fri Feb 24 08:22:30 2017 -0800
Committer: Holden Karau <hol...@us.ibm.com>
Committed: Fri Feb 24 08:22:30 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py | 11 ++++++++++-
 python/pyspark/sql/tests.py     | 25 +++++++++++++++----------
 2 files changed, 25 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4a5e38f5/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d261720..426a4a8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()):
     +----------+--------------+------------+
     """
     def _udf(f, returnType=StringType()):
-        return UserDefinedFunction(f, returnType)
+        udf_obj = UserDefinedFunction(f, returnType)
+
+        @functools.wraps(f)
+        def wrapper(*args):
+            return udf_obj(*args)
+
+        wrapper.func = udf_obj.func
+        wrapper.returnType = udf_obj.returnType
+
+        return wrapper
 
     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):

http://git-wip-us.apache.org/repos/asf/spark/blob/4a5e38f5/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index abd68bf..fd083e4 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -266,9 +266,6 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
-        with self.assertRaises(ValueError):
-            data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()
-
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
@@ -578,6 +575,21 @@ class SQLTests(ReusedPySparkTestCase):
             [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
         )
 
+    def test_udf_wrapper(self):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import IntegerType
+
+        def f(x):
+            """Identity"""
+            return x
+
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
@@ -963,13 +975,6 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
 
-    def test_column_alias_metadata(self):
-        df = self.df
-        df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
-        self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
-        with self.assertRaises(AssertionError):
-            df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))
-
     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
         df = self.sc.parallelize(vals).toDF()

