Repository: spark Updated Branches: refs/heads/master 8bd05c9db -> 2b8906c43
[SPARK-14739][PYSPARK] Fix Vectors parser bugs ## What changes were proposed in this pull request? The PySpark deserialization has a bug that manifests when deserializing all-zero sparse vectors such as `(4,[],[])`: splitting an empty index or value list on ',' yields a single empty-string token, which cannot be cast to a number. This fix filters out empty string tokens before casting, so properly stringified SparseVectors are parsed successfully. ## How was this patch tested? Standard unit-tests similar to other methods. Author: Arash Parsa <arash@ip-192-168-50-106.ec2.internal> Author: Arash Parsa <aras...@gmail.com> Author: Vishnu Prasad <vishnu...@gmail.com> Author: Vishnu Prasad S <vishnu...@gmail.com> Closes #12516 from arashpa/SPARK-14739. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b8906c4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b8906c4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b8906c4 Branch: refs/heads/master Commit: 2b8906c43760591f2e2da99bf0e34fa1bb63bfd1 Parents: 8bd05c9 Author: Arash Parsa <arash@ip-192-168-50-106.ec2.internal> Authored: Thu Apr 21 11:29:24 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Thu Apr 21 11:29:24 2016 +0100 ---------------------------------------------------------------------- python/pyspark/mllib/linalg/__init__.py | 6 +++--- python/pyspark/mllib/tests.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2b8906c4/python/pyspark/mllib/linalg/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index abf00a4..4cd7306 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -293,7 +293,7 @@ class DenseVector(Vector): s = s[start + 1: end] try: - values = [float(val) for val in s.split(',')] + values = [float(val) for val in 
s.split(',') if val] except ValueError: raise ValueError("Unable to parse values from %s" % s) return DenseVector(values) @@ -586,7 +586,7 @@ class SparseVector(Vector): new_s = s[ind_start + 1: ind_end] ind_list = new_s.split(',') try: - indices = [int(ind) for ind in ind_list] + indices = [int(ind) for ind in ind_list if ind] except ValueError: raise ValueError("Unable to parse indices from %s." % new_s) s = s[ind_end + 1:].strip() @@ -599,7 +599,7 @@ class SparseVector(Vector): raise ValueError("Values array should end with ']'.") val_list = s[val_start + 1: val_end].split(',') try: - values = [float(val) for val in val_list] + values = [float(val) for val in val_list if val] except ValueError: raise ValueError("Unable to parse values from %s." % s) return SparseVector(size, indices, values) http://git-wip-us.apache.org/repos/asf/spark/blob/2b8906c4/python/pyspark/mllib/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f272da5..53a1d2c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -393,14 +393,20 @@ class VectorTests(MLlibTestCase): self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) def test_parse_vector(self): + a = DenseVector([]) + self.assertEqual(str(a), '[]') + self.assertEqual(Vectors.parse(str(a)), a) a = DenseVector([3, 4, 6, 7]) - self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]') - self.assertTrue(Vectors.parse(str(a)), a) + self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') + self.assertEqual(Vectors.parse(str(a)), a) + a = SparseVector(4, [], []) + self.assertEqual(str(a), '(4,[],[])') + self.assertEqual(SparseVector.parse(str(a)), a) a = SparseVector(4, [0, 2], [3, 4]) - self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])') - self.assertTrue(Vectors.parse(str(a)), a) + self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') + self.assertEqual(Vectors.parse(str(a)), a) a = SparseVector(10, [0, 1], [4, 5]) - 
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) + self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) def test_norms(self): a = DenseVector([0, 2, 3, -1]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org