Repository: spark
Updated Branches:
  refs/heads/master 8bd05c9db -> 2b8906c43


[SPARK-14739][PYSPARK] Fix Vectors parser bugs

## What changes were proposed in this pull request?

The PySpark vector deserialization has a bug that shows up when deserializing 
all-zero sparse vectors. This fix filters out empty string tokens before 
casting, so that properly stringified SparseVectors are parsed successfully.

## How was this patch tested?

Standard unit tests, similar to those for other methods.

Author: Arash Parsa <arash@ip-192-168-50-106.ec2.internal>
Author: Arash Parsa <aras...@gmail.com>
Author: Vishnu Prasad <vishnu...@gmail.com>
Author: Vishnu Prasad S <vishnu...@gmail.com>

Closes #12516 from arashpa/SPARK-14739.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b8906c4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b8906c4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b8906c4

Branch: refs/heads/master
Commit: 2b8906c43760591f2e2da99bf0e34fa1bb63bfd1
Parents: 8bd05c9
Author: Arash Parsa <arash@ip-192-168-50-106.ec2.internal>
Authored: Thu Apr 21 11:29:24 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Apr 21 11:29:24 2016 +0100

----------------------------------------------------------------------
 python/pyspark/mllib/linalg/__init__.py |  6 +++---
 python/pyspark/mllib/tests.py           | 16 +++++++++++-----
 2 files changed, 14 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b8906c4/python/pyspark/mllib/linalg/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index abf00a4..4cd7306 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -293,7 +293,7 @@ class DenseVector(Vector):
         s = s[start + 1: end]
 
         try:
-            values = [float(val) for val in s.split(',')]
+            values = [float(val) for val in s.split(',') if val]
         except ValueError:
             raise ValueError("Unable to parse values from %s" % s)
         return DenseVector(values)
@@ -586,7 +586,7 @@ class SparseVector(Vector):
         new_s = s[ind_start + 1: ind_end]
         ind_list = new_s.split(',')
         try:
-            indices = [int(ind) for ind in ind_list]
+            indices = [int(ind) for ind in ind_list if ind]
         except ValueError:
             raise ValueError("Unable to parse indices from %s." % new_s)
         s = s[ind_end + 1:].strip()
@@ -599,7 +599,7 @@ class SparseVector(Vector):
             raise ValueError("Values array should end with ']'.")
         val_list = s[val_start + 1: val_end].split(',')
         try:
-            values = [float(val) for val in val_list]
+            values = [float(val) for val in val_list if val]
         except ValueError:
             raise ValueError("Unable to parse values from %s." % s)
         return SparseVector(size, indices, values)

http://git-wip-us.apache.org/repos/asf/spark/blob/2b8906c4/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f272da5..53a1d2c 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -393,14 +393,20 @@ class VectorTests(MLlibTestCase):
         self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
 
     def test_parse_vector(self):
+        a = DenseVector([])
+        self.assertEqual(str(a), '[]')
+        self.assertEqual(Vectors.parse(str(a)), a)
         a = DenseVector([3, 4, 6, 7])
-        self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
-        self.assertTrue(Vectors.parse(str(a)), a)
+        self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
+        self.assertEqual(Vectors.parse(str(a)), a)
+        a = SparseVector(4, [], [])
+        self.assertEqual(str(a), '(4,[],[])')
+        self.assertEqual(SparseVector.parse(str(a)), a)
         a = SparseVector(4, [0, 2], [3, 4])
-        self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
-        self.assertTrue(Vectors.parse(str(a)), a)
+        self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
+        self.assertEqual(Vectors.parse(str(a)), a)
         a = SparseVector(10, [0, 1], [4, 5])
-        self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
+        self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
 
     def test_norms(self):
         a = DenseVector([0, 2, 3, -1])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to