Repository: incubator-systemml
Updated Branches:
  refs/heads/master 10b7b8669 -> 67f16c46e


[SYSTEMML-1232] Migrate stringDataFrameToVectorDataFrame to ml Vector

Restore and migrate RDDConverterUtilsExt.stringDataFrameToVectorDataFrame
method from mllib Vector class to ml Vector class. Use NumericParser since
ml.linalg.Vectors.parse() does not exist.

Closes #379.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/67f16c46
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/67f16c46
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/67f16c46

Branch: refs/heads/master
Commit: 67f16c46e692adfe2533cc31103374d6e5d39bb3
Parents: 10b7b86
Author: Deron Eriksson <[email protected]>
Authored: Sun Feb 5 14:21:04 2017 -0800
Committer: Deron Eriksson <[email protected]>
Committed: Fri Feb 10 14:05:29 2017 -0800

----------------------------------------------------------------------
 pom.xml                                         |   1 +
 .../spark/utils/RDDConverterUtilsExt.java       | 112 +++++++++++--
 .../conversion/RDDConverterUtilsExtTest.java    | 160 +++++++++++++++++++
 .../integration/conversion/ZPackageSuite.java   |  36 +++++
 4 files changed, 296 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index ab088c8..f81557e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -372,6 +372,7 @@
 
                                        <includes>
                                                
<include>**/integration/applications/**/*Suite.java</include>
+                                               
<include>**/integration/conversion/*Suite.java</include>
                                                
<include>**/integration/functions/data/*Suite.java</include>
                                                
<include>**/integration/functions/gdfo/*Suite.java</include>
                                                
<include>**/integration/functions/sparse/*Suite.java</include>

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index e0d347f..e3b4d0c 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -21,15 +21,14 @@ package org.apache.sysml.runtime.instructions.spark.utils;
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Iterator;
-import java.util.List;
-import java.util.Scanner;
 
 import org.apache.hadoop.io.Text;
-import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
+import org.apache.spark.SparkException;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -40,6 +39,7 @@ import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
+import org.apache.spark.mllib.util.NumericParser;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -47,15 +47,7 @@ import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-
-import scala.Tuple2;
-
 import org.apache.sysml.runtime.DMLRuntimeException;
-import 
org.apache.sysml.runtime.instructions.spark.functions.ConvertMatrixBlockToIJVLines;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixCell;
@@ -63,7 +55,8 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
 import org.apache.sysml.runtime.matrix.mapred.ReblockBuffer;
 import org.apache.sysml.runtime.util.FastStringTokenizer;
-import org.apache.sysml.runtime.util.UtilFunctions;
+
+import scala.Tuple2;
 
 /**
  * NOTE: These are experimental converter utils. Once thoroughly tested, they
@@ -362,4 +355,97 @@ public class RDDConverterUtilsExt
                        ret.addAll(SparkUtils.fromIndexedMatrixBlock(rettmp));
                }
        }
+
+       /**
+        * Convert a dataframe of comma-separated string rows to a dataframe of
+        * ml.linalg.Vector rows.
+        * 
+        * <p>
+        * Each input row may have at most one column, whose value must be a
+        * String or null. Up to two levels of matching enclosing parentheses or
+        * square brackets are stripped, and the remaining comma-separated values
+        * are parsed as doubles into a dense vector. A null value is passed
+        * through as a null vector; column names are preserved.
+        * 
+        * <p>
+        * Example input rows:<br>
+        * 
+        * <code>
+        * ((1.2, 4.3, 3.4))<br>
+        * (1.2, 3.4, 2.2)<br>
+        * [[1.2, 34.3, 1.2, 1.25]]<br>
+        * [1.2, 3.4]<br>
+        * </code>
+        * 
+        * <p>
+        * The conversion is lazy: malformed rows are only detected when the
+        * returned dataframe is evaluated (e.g. collected), at which point Spark
+        * reports the per-row DMLRuntimeException wrapped in a SparkException.
+        * 
+        * @param sqlContext
+        *            Spark SQL Context
+        * @param inputDF
+        *            dataframe of comma-separated row strings to convert to
+        *            dataframe of ml.linalg.Vector rows
+        * @return dataframe of ml.linalg.Vector rows
+        * @throws DMLRuntimeException
+        *             declared for API compatibility; conversion errors surface
+        *             lazily as described above
+        */
+       public static Dataset<Row> stringDataFrameToVectorDataFrame(SQLContext sqlContext, Dataset<Row> inputDF)
+                       throws DMLRuntimeException {
+
+               // every output column keeps its name but becomes a nullable vector column
+               StructField[] oldSchema = inputDF.schema().fields();
+               StructField[] newSchema = new StructField[oldSchema.length];
+               for (int i = 0; i < oldSchema.length; i++) {
+                       String colName = oldSchema[i].name();
+                       newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
+               }
+
+               // converter: row of (at most one) string -> row of vectors
+               class StringToVector implements Function<Row, Row> {
+                       private static final long serialVersionUID = -4733816995375745659L;
+
+                       @Override
+                       public Row call(Row oldRow) throws Exception {
+                               int oldNumCols = oldRow.length();
+                               if (oldNumCols > 1) {
+                                       throw new DMLRuntimeException("The row must have at most one column");
+                               }
+
+                               Object[] fields = new Object[oldNumCols];
+                               for (int i = 0; i < oldNumCols; i++) {
+                                       Object ci = oldRow.get(i);
+                                       if (ci == null) {
+                                               fields[i] = null;
+                                       } else if (ci instanceof String) {
+                                               fields[i] = parseVectorString((String) ci);
+                                       } else {
+                                               throw new DMLRuntimeException("Only String is supported");
+                                       }
+                               }
+                               return RowFactory.create(fields);
+                       }
+               }
+
+               // output DF; map the rows directly (a zipWithIndex row index was never
+               // used and would cost an extra Spark job on multi-partition input)
+               JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().map(new StringToVector());
+               Dataset<Row> outDF = sqlContext.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
+
+               return outDF;
+       }
+
+       /**
+        * Parse a single string such as "((1.2, 4.3))" or "[1.2, 3.4]" into a
+        * dense ml.linalg.Vector. At most two levels of matching enclosing
+        * parentheses or square brackets are removed before parsing, so "[]",
+        * "()" and "[[]]" all yield an empty vector.
+        * 
+        * @param str
+        *            the string to parse (must not be null)
+        * @return dense vector of the parsed values
+        * @throws DMLRuntimeException
+        *             if the string is blank or its values cannot be parsed as
+        *             doubles
+        */
+       private static Vector parseVectorString(String str) throws DMLRuntimeException {
+               StringBuilder sb = new StringBuilder(str.trim());
+               if (sb.length() == 0) {
+                       throw new DMLRuntimeException("Error converting to double array. Empty string.");
+               }
+
+               // remove up to two levels of nesting; the length guard keeps charAt()
+               // off an empty buffer once a pair like "()" has been stripped
+               for (int level = 0; level < 2 && sb.length() >= 2; level++) {
+                       char first = sb.charAt(0);
+                       char last = sb.charAt(sb.length() - 1);
+                       if ((first == '(' && last == ')') || (first == '[' && last == ']')) {
+                               sb.deleteCharAt(0);
+                               sb.setLength(sb.length() - 1);
+                       }
+               }
+
+               // normalize separators and wrap in [ ] so NumericParser yields a double[]
+               String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
+               try {
+                       double[] doubles = (double[]) NumericParser.parse(ncis);
+                       return Vectors.dense(doubles);
+               } catch (Exception e) {
+                       // NumericParser is Scala code that throws the checked SparkException without
+                       // declaring it, so javac rejects a catch clause for SparkException here
+                       throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
 
b/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
new file mode 100644
index 0000000..7a69423
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.conversion;
+
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class RDDConverterUtilsExtTest extends AutomatedTestBase {
+
+       private static SparkConf conf;
+       private static JavaSparkContext sc;
+
+       @BeforeClass
+       public static void setUpClass() {
+               if (conf == null)
+                       conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("RDDConverterUtilsExtTest")
+                                       .setMaster("local");
+               if (sc == null)
+                       sc = new JavaSparkContext(conf);
+       }
+
+       @Override
+       public void setUp() {
+               // no setup required
+       }
+
+       /**
+        * Convert a basic String to a spark.sql.Row.
+        */
+       static class StringToRow implements Function<String, Row> {
+               private static final long serialVersionUID = 3945939649355731805L;
+
+               @Override
+               public Row call(String str) throws Exception {
+                       return RowFactory.create(str);
+               }
+       }
+
+       /**
+        * Build a dataframe with a single nullable string column "C1", holding one
+        * row per element of the given list.
+        */
+       private static Dataset<Row> createStringDataFrame(SQLContext sqlContext, List<String> list) {
+               JavaRDD<String> javaRddString = sc.parallelize(list);
+               JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               return sqlContext.createDataFrame(javaRddRow, schema);
+       }
+
+       /**
+        * Collect the dataframe, then check that it has exactly as many rows as
+        * expected and that each row's string form is one of the expected results.
+        * The count check keeps an empty or truncated output from passing
+        * vacuously.
+        */
+       private static void assertRowsMatch(List<String> expectedResults, Dataset<Row> outDF) {
+               List<Row> outputList = outDF.collectAsList();
+               org.junit.Assert.assertEquals("Unexpected number of output rows", expectedResults.size(),
+                               outputList.size());
+               for (Row row : outputList) {
+                       assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
+               }
+       }
+
+       @Test
+       public void testStringDataFrameToVectorDataFrame() throws DMLRuntimeException {
+               List<String> list = new ArrayList<String>();
+               list.add("((1.2, 4.3, 3.4))");
+               list.add("(1.2, 3.4, 2.2)");
+               list.add("[[1.2, 34.3, 1.2, 1.25]]");
+               list.add("[1.2, 3.4]");
+               SQLContext sqlContext = new SQLContext(sc);
+               Dataset<Row> inDF = createStringDataFrame(sqlContext, list);
+               Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF);
+
+               List<String> expectedResults = new ArrayList<String>();
+               expectedResults.add("[[1.2,4.3,3.4]]");
+               expectedResults.add("[[1.2,3.4,2.2]]");
+               expectedResults.add("[[1.2,34.3,1.2,1.25]]");
+               expectedResults.add("[[1.2,3.4]]");
+
+               assertRowsMatch(expectedResults, outDF);
+       }
+
+       @Test
+       public void testStringDataFrameToVectorDataFrameNull() throws DMLRuntimeException {
+               List<String> list = new ArrayList<String>();
+               list.add("[1.2, 3.4]");
+               list.add(null);
+               SQLContext sqlContext = new SQLContext(sc);
+               Dataset<Row> inDF = createStringDataFrame(sqlContext, list);
+               Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF);
+
+               List<String> expectedResults = new ArrayList<String>();
+               expectedResults.add("[[1.2,3.4]]");
+               expectedResults.add("[null]");
+
+               assertRowsMatch(expectedResults, outDF);
+       }
+
+       @Test(expected = SparkException.class)
+       public void testStringDataFrameToVectorDataFrameNonNumbers() throws DMLRuntimeException {
+               List<String> list = new ArrayList<String>();
+               list.add("[cheeseburger,fries]");
+               SQLContext sqlContext = new SQLContext(sc);
+               Dataset<Row> inDF = createStringDataFrame(sqlContext, list);
+               Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF);
+               // trigger evaluation to throw exception
+               outDF.collectAsList();
+       }
+
+       @After
+       public void tearDown() {
+               super.tearDown();
+       }
+
+       @AfterClass
+       public static void tearDownClass() {
+               // stop spark context to allow single jvm tests (otherwise the
+               // next test that tries to create a SparkContext would fail);
+               // guard against a failed setUpClass that never created it
+               if (sc != null) {
+                       sc.stop();
+               }
+               sc = null;
+               conf = null;
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
new file mode 100644
index 0000000..b8ab13d
--- /dev/null
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.conversion;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Group together the tests in this package/related subpackages into a single
+ * suite so that the Maven build won't run two of them at once. This class is
+ * just a holder for the JUnit annotations below.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+  org.apache.sysml.test.integration.conversion.RDDConverterUtilsExtTest.class
+})
+public class ZPackageSuite {
+
+}

Reply via email to