Repository: spark
Updated Branches:
refs/heads/master 7f3c778fd -> 6a5a7254d
[SPARK-18667][PYSPARK][SQL] Change the way to group rows in BatchEvalPythonExec so the input_file_name function can work with UDFs in PySpark
## What changes were proposed in this pull request?
`input_file_name` doesn't return the filename when it is used with a UDF in
PySpark. An example shows the problem:

    from pyspark.sql.functions import *
    from pyspark.sql.types import *

    def filename(path):
        return path

    sourceFile = udf(filename, StringType())
    spark.read.json("tmp.json").select(sourceFile(input_file_name())).show()

    +---------------------------+
    |filename(input_file_name())|
    +---------------------------+
    |                           |
    +---------------------------+
The cause of this issue is that `BatchEvalPythonExec` groups rows for batched
processing by the PythonUDF: currently we group the rows first and then
evaluate the expressions on them. If the data holds fewer rows than a full
group, the iterator is consumed to the end before any evaluation takes place.
However, once the iterator reaches the end, the input filename is unset, so
the `input_file_name` expression can no longer return the correct filename.
This patch changes the way the batch of rows is built: we evaluate the
expressions on each row first and then group the evaluated results into
batches, as the sketch below illustrates.
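To make the ordering concrete, here is a minimal script-style Scala sketch of
the two groupings (it can be pasted into a Scala REPL). Everything in it is a
hypothetical stand-in rather than Spark code: `InputFileHolder` mimics the
thread-local holder that Spark clears at end-of-file, `rowsFromFile` mimics a
file scan iterator, and `evalExpr` mimics evaluating `input_file_name` against
a row.

    // Stand-in for Spark's thread-local holding the current input file name.
    object InputFileHolder { var current: String = "" }

    // Stand-in for a file scan: sets the holder while reading, unsets it at EOF.
    def rowsFromFile(name: String, numRows: Int): Iterator[Int] = {
      InputFileHolder.current = name
      val base = Iterator.range(0, numRows)
      new Iterator[Int] {
        def hasNext: Boolean = {
          val more = base.hasNext
          if (!more) InputFileHolder.current = "" // unset once the file is exhausted
          more
        }
        def next(): Int = base.next()
      }
    }

    // Stand-in for evaluating input_file_name() against a row.
    def evalExpr(row: Int): String = s"row $row from '${InputFileHolder.current}'"

    // Old order: group first, evaluate after. With fewer than 100 rows,
    // grouped(100) drains the iterator (unsetting the holder) before
    // evalExpr runs, so every row sees an empty file name.
    val buggy = rowsFromFile("people1.json", 3).grouped(100).map(_.map(evalExpr))
    println(buggy.toList.flatten) // row 0 from '', row 1 from '', row 2 from ''

    // New order: evaluate per row, then group. evalExpr runs as each row is
    // pulled, while the holder still points at the file.
    val fixed = rowsFromFile("people1.json", 3).map(evalExpr).grouped(100)
    println(fixed.toList.flatten) // row 0 from 'people1.json', ...

The patch keeps the 100-row batching for the Python round trip; it only moves
the projection and conversion ahead of the grouping, so each row is
materialized while the scan's per-thread state is still valid.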
## How was this patch tested?
Added a unit test to PySpark.
Author: Liang-Chi Hsieh <[email protected]>
Closes #16115 from viirya/fix-py-udf-input-filename.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a5a7254
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a5a7254
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a5a7254
Branch: refs/heads/master
Commit: 6a5a7254dc37952505989e9e580a14543adb730c
Parents: 7f3c778
Author: Liang-Chi Hsieh <[email protected]>
Authored: Thu Dec 8 23:22:18 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Thu Dec 8 23:22:18 2016 +0800
----------------------------------------------------------------------
python/pyspark/sql/tests.py | 8 +++++
.../execution/python/BatchEvalPythonExec.scala | 35 +++++++++-----------
2 files changed, 24 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6a5a7254/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 50df68b..66320bd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -412,6 +412,14 @@ class SQLTests(ReusedPySparkTestCase):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_with_input_file_name(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+        sourceFile = udf(lambda path: path, StringType())
+        filePath = "python/test_support/sql/people1.json"
+        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+        self.assertTrue(row[0].find("people1.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
http://git-wip-us.apache.org/repos/asf/spark/blob/6a5a7254/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index dcaf2c7..7a5ac48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -119,26 +119,23 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { inputRow =>
-          queue.add(inputRow.asInstanceOf[UnsafeRow])
-          val row = projection(inputRow)
-          if (needConversion) {
-            EvaluatePython.toJava(row, schema)
-          } else {
-            // fast path for these types that does not need conversion in Python
-            val fields = new Array[Any](row.numFields)
-            var i = 0
-            while (i < row.numFields) {
-              val dt = dataTypes(i)
-              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-              i += 1
-            }
-            fields
+      val inputIterator = iter.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        val row = projection(inputRow)
+        if (needConversion) {
+          EvaluatePython.toJava(row, schema)
+        } else {
+          // fast path for these types that does not need conversion in Python
+          val fields = new Array[Any](row.numFields)
+          var i = 0
+          while (i < row.numFields) {
+            val dt = dataTypes(i)
+            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+            i += 1
           }
-        }.toArray
-        pickle.dumps(toBePickled)
-      }
+          fields
+        }
+      }.grouped(100).map(x => pickle.dumps(x.toArray))
 
       val context = TaskContext.get()
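Note that `Iterator.grouped` is itself lazy, so moving `map` ahead of
`grouped` changes when each row is evaluated, not how much is buffered per
batch. A quick stand-alone Scala check of that behavior (an illustration under
that assumption, not Spark code):

    // Each next() on a grouped iterator pulls upstream elements one at a
    // time, so per-element work runs as each row is pulled, before the
    // upstream iterator hits EOF.
    val upstream = Iterator.tabulate(5) { i => println(s"evaluated row $i"); i }
    val batches = upstream.grouped(2)
    batches.next() // prints "evaluated row 0" and "evaluated row 1" only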