Repository: spark
Updated Branches:
  refs/heads/master e0eeb0f89 -> 7b64f7aa0
[SPARK-18541][PYTHON] Add metadata parameter to pyspark.sql.Column.alias()

## What changes were proposed in this pull request?

Add a `metadata` keyword parameter to `pyspark.sql.Column.alias()` to allow users to mix in metadata while manipulating `DataFrame`s in `pyspark`. Without this, I believe it was necessary to pass back through `SparkSession.createDataFrame` each time a user wanted to manipulate `StructField.metadata` in `pyspark`.

This pull request also improves consistency between the Scala and Python APIs (i.e. I did not add any functionality that was not already in the Scala API).

Discussed ahead of time on JIRA with marmbrus.

## How was this patch tested?

Added unit tests (and doc tests). Ran the pertinent tests manually.

Author: Sheamus K. Parkes <shea.par...@milliman.com>

Closes #16094 from shea-parkes/pyspark-column-alias-metadata.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7b64f7aa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7b64f7aa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7b64f7aa

Branch: refs/heads/master
Commit: 7b64f7aa03a49adca5fcafe6fff422823b587514
Parents: e0eeb0f
Author: Sheamus K. Parkes <shea.par...@milliman.com>
Authored: Tue Feb 14 09:57:43 2017 -0800
Committer: Holden Karau <hol...@us.ibm.com>
Committed: Tue Feb 14 09:57:43 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/column.py | 26 +++++++++++++++++++++++---
 python/pyspark/sql/tests.py  | 10 ++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7b64f7aa/python/pyspark/sql/column.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 73c8672..0df187a 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -17,6 +17,7 @@
 
 import sys
 import warnings
+import json
 
 if sys.version >= '3':
     basestring = str
@@ -303,19 +304,38 @@ class Column(object):
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
     @since(1.3)
-    def alias(self, *alias):
+    def alias(self, *alias, **kwargs):
         """
         Returns this column aliased with a new name or names (in the case of expressions that
         return more than one column, such as explode).
 
+        :param alias: strings of desired column names (collects all positional arguments passed)
+        :param metadata: a dict of information to be stored in the ``metadata`` attribute of the
+            corresponding :class:`StructField` (optional, keyword-only argument)
+
+        .. versionchanged:: 2.2
+           Added optional ``metadata`` argument.
+ >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] + >>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max'] + 99 """ + metadata = kwargs.pop('metadata', None) + assert not kwargs, 'Unexpected kwargs where passed: %s' % kwargs + + sc = SparkContext._active_spark_context if len(alias) == 1: - return Column(getattr(self._jc, "as")(alias[0])) + if metadata: + jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson( + json.dumps(metadata)) + return Column(getattr(self._jc, "as")(alias[0], jmeta)) + else: + return Column(getattr(self._jc, "as")(alias[0])) else: - sc = SparkContext._active_spark_context + if metadata: + raise ValueError('metadata can only be provided for a single column') return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.") http://git-wip-us.apache.org/repos/asf/spark/blob/7b64f7aa/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7372167..62e1a8c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -266,6 +266,9 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + with self.assertRaises(ValueError): + data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count() + def test_and_in_expression(self): self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) @@ -895,6 +898,13 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(self.testData, df.select(df.key, df.value).collect()) self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + def test_column_alias_metadata(self): + df = self.df + df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'})) + self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key') + with self.assertRaises(AssertionError): + df.select(df.key.alias('pk', metdata={'label': 'Primary Key'})) + def test_freqItems(self): vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] df = self.sc.parallelize(vals).toDF() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org