Updated Branches:
  refs/heads/branch-0.9 0f60ef2c4 -> 5edbd175e
Merge pull request #523 from JoshRosen/SPARK-1043

Switch from MUTF8 to UTF8 in PySpark serializers.

This fixes SPARK-1043, a bug introduced in 0.9.0 where PySpark couldn't
serialize strings > 64kB.

This fix was written by @tyro89 and @bouk in #512. This commit squashes and
rebases their pull request in order to fix some merge conflicts.

(cherry picked from commit f8c742ce274fbae2a9e616d4c97469b6a22069bb)
Signed-off-by: Patrick Wendell <pwend...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/5edbd175
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/5edbd175
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/5edbd175

Branch: refs/heads/branch-0.9
Commit: 5edbd175e07dc9704b1babb9c5e8d97fb644be65
Parents: 0f60ef2
Author: Josh Rosen <joshro...@apache.org>
Authored: Tue Jan 28 21:30:20 2014 -0800
Committer: Patrick Wendell <pwend...@gmail.com>
Committed: Tue Jan 28 21:32:58 2014 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 18 +++++++---
 .../spark/api/python/PythonRDDSuite.scala       | 35 ++++++++++++++++++++
 python/pyspark/context.py                       |  4 +--
 python/pyspark/serializers.py                   |  6 ++--
 python/pyspark/worker.py                        |  8 ++---
 5 files changed, 57 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5edbd175/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 57bde8d..46d53e3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -62,7 +62,7 @@ private[spark] class PythonRDD[T: ClassTag](
         // Partition index
         dataOut.writeInt(split.index)
         // sparkFilesDir
-        dataOut.writeUTF(SparkFiles.getRootDirectory)
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
         // Broadcast variables
         dataOut.writeInt(broadcastVars.length)
         for (broadcast <- broadcastVars) {
@@ -72,7 +72,9 @@ private[spark] class PythonRDD[T: ClassTag](
         }
         // Python includes (*.zip and *.egg files)
         dataOut.writeInt(pythonIncludes.length)
-        pythonIncludes.foreach(dataOut.writeUTF)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
         dataOut.flush()
         // Serialized command:
         dataOut.writeInt(command.length)
@@ -219,7 +221,7 @@ private[spark] object PythonRDD {
          }
        case string: String =>
          newIter.asInstanceOf[Iterator[String]].foreach { str =>
-           dataOut.writeUTF(str)
+           writeUTF(str, dataOut)
          }
        case pair: Tuple2[_, _] =>
          pair._1 match {
@@ -232,8 +234,8 @@ private[spark] object PythonRDD {
              }
            case stringPair: String =>
              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
-               dataOut.writeUTF(pair._1)
-               dataOut.writeUTF(pair._2)
+               writeUTF(pair._1, dataOut)
+               writeUTF(pair._2, dataOut)
              }
            case other =>
              throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
@@ -244,6 +246,12 @@ private[spark] object PythonRDD {
     }
   }
 
+  def writeUTF(str: String, dataOut: DataOutputStream) {
+    val bytes = str.getBytes("UTF-8")
+    dataOut.writeInt(bytes.length)
+    dataOut.write(bytes)
+  }
+
   def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
     writeToFile(items.asScala, filename)


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5edbd175/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
new file mode 100644
index 0000000..1bebfe5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.api.python.PythonRDD
+
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+
+class PythonRDDSuite extends FunSuite {
+
+    test("Writing large strings to the worker") {
+        val input: List[String] = List("a"*100000)
+        val buffer = new DataOutputStream(new ByteArrayOutputStream)
+        PythonRDD.writeIteratorToStream(input.iterator, buffer)
+    }
+
+}
+


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5edbd175/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index f955aad..f318b5d 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -27,7 +27,7 @@ from pyspark.broadcast import Broadcast
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -234,7 +234,7 @@ class SparkContext(object):
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
         return RDD(self._jsc.textFile(name, minSplits), self,
-                   MUTF8Deserializer())
+                   UTF8Deserializer())
 
     def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5edbd175/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a500ab..8c6ad79 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -261,13 +261,13 @@ class MarshalSerializer(FramedSerializer):
     loads = marshal.loads
 
 
-class MUTF8Deserializer(Serializer):
+class UTF8Deserializer(Serializer):
     """
-    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    Deserializes streams written by getBytes.
     """
 
     def loads(self, stream):
-        length = struct.unpack('>H', stream.read(2))[0]
+        length = read_int(stream)
         return stream.read(length).decode('utf8')
 
     def load_stream(self, stream):


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5edbd175/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d77981f..4be4063 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,11 +30,11 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
 
 
 pickleSer = PickleSerializer()
-mutf8_deserializer = MUTF8Deserializer()
+utf8_deserializer = UTF8Deserializer()
 
 
 def report_times(outfile, boot, init, finish):
@@ -51,7 +51,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = mutf8_deserializer.loads(infile)
+    spark_files_dir = utf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -66,7 +66,7 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes = read_int(infile)
     for _ in range(num_python_includes):
-        filename = mutf8_deserializer.loads(infile)
+        filename = utf8_deserializer.loads(infile)
         sys.path.append(os.path.join(spark_files_dir, filename))
 
     command = pickleSer._read_with_length(infile)
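

For context on why the wire-format change above matters: java.io.DataOutputStream.writeUTF()
prefixes its payload with a 2-byte unsigned length, so it rejects any string whose encoded
form exceeds 65535 bytes, which is the root cause of SPARK-1043. The sketch below is not
part of this commit; the WriteUTFDemo object and its main method are illustrative only. It
reproduces that failure and shows the int-length-plus-raw-UTF-8 framing that the new
PythonRDD.writeUTF helper uses instead.

import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException}

// Illustrative sketch only; mirrors the framing added in this commit, not Spark code itself.
object WriteUTFDemo {

  // Same shape as the PythonRDD.writeUTF helper: a 4-byte length, then the raw UTF-8 bytes.
  def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
    val bytes = str.getBytes("UTF-8")
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

  def main(args: Array[String]): Unit = {
    val big = "a" * 100000  // ~100 kB, well past writeUTF's 64 kB ceiling
    val out = new DataOutputStream(new ByteArrayOutputStream())

    // DataOutputStream.writeUTF stores the length in an unsigned short,
    // so it throws once the encoded string exceeds 65535 bytes.
    try {
      out.writeUTF(big)
      println("writeUTF unexpectedly succeeded")
    } catch {
      case _: UTFDataFormatException =>
        println("writeUTF rejected the 100 kB string, as in SPARK-1043")
    }

    // The int-length framing has no such ceiling.
    writeUTF(big, out)
    println(s"writeUTF helper wrote ${big.length} characters without error")
  }
}

On the Python side, the renamed UTF8Deserializer mirrors this framing: it reads a 4-byte
length with read_int instead of a 2-byte unsigned short, then decodes that many bytes as UTF-8.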