Repository: incubator-systemml Updated Branches: refs/heads/master 10b7b8669 -> 67f16c46e
[SYSTEMML-1232] Migrate stringDataFrameToVectorDataFrame to ml Vector Restore and migrate RDDConverterUtilsExt.stringDataFrameToVectorDataFrame method from mllib Vector class to ml Vector class. Use NumericParser since ml.linalg.Vectors.parse() does not exist. Closes #379. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/67f16c46 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/67f16c46 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/67f16c46 Branch: refs/heads/master Commit: 67f16c46e692adfe2533cc31103374d6e5d39bb3 Parents: 10b7b86 Author: Deron Eriksson <[email protected]> Authored: Sun Feb 5 14:21:04 2017 -0800 Committer: Deron Eriksson <[email protected]> Committed: Fri Feb 10 14:05:29 2017 -0800 ---------------------------------------------------------------------- pom.xml | 1 + .../spark/utils/RDDConverterUtilsExt.java | 112 +++++++++++-- .../conversion/RDDConverterUtilsExtTest.java | 160 +++++++++++++++++++ .../integration/conversion/ZPackageSuite.java | 36 +++++ 4 files changed, 296 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index ab088c8..f81557e 100644 --- a/pom.xml +++ b/pom.xml @@ -372,6 +372,7 @@ <includes> <include>**/integration/applications/**/*Suite.java</include> + <include>**/integration/conversion/*Suite.java</include> <include>**/integration/functions/data/*Suite.java</include> <include>**/integration/functions/gdfo/*Suite.java</include> <include>**/integration/functions/sparse/*Suite.java</include> http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java index e0d347f..e3b4d0c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java @@ -21,15 +21,14 @@ package org.apache.sysml.runtime.instructions.spark.utils; import java.io.IOException; import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; -import java.util.Arrays; import java.util.Iterator; -import java.util.List; -import java.util.Scanner; import org.apache.hadoop.io.Text; -import org.apache.spark.Accumulator; import org.apache.spark.SparkContext; +import org.apache.spark.SparkException; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -40,6 +39,7 @@ import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; import org.apache.spark.mllib.linalg.distributed.MatrixEntry; +import org.apache.spark.mllib.util.NumericParser; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -47,15 +47,7 @@ import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -import scala.Tuple2; - import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.instructions.spark.functions.ConvertMatrixBlockToIJVLines; -import org.apache.sysml.runtime.io.IOUtilFunctions; import 
org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixCell; @@ -63,7 +55,8 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; import org.apache.sysml.runtime.matrix.mapred.ReblockBuffer; import org.apache.sysml.runtime.util.FastStringTokenizer; -import org.apache.sysml.runtime.util.UtilFunctions; + +import scala.Tuple2; /** * NOTE: These are experimental converter utils. Once thoroughly tested, they @@ -362,4 +355,97 @@ public class RDDConverterUtilsExt ret.addAll(SparkUtils.fromIndexedMatrixBlock(rettmp)); } } + + /** + * Convert a dataframe of comma-separated string rows to a dataframe of + * ml.linalg.Vector rows. Up to two levels of enclosing () or [] are stripped; null cells stay null. + * + * <p> + * Example input rows:<br> + * + * <code> + * ((1.2, 4.3, 3.4))<br> + * (1.2, 3.4, 2.2)<br> + * [[1.2, 34.3, 1.2, 1.25]]<br> + * [1.2, 3.4]<br> + * </code> + * + * @param sqlContext + * Spark SQL Context + * @param inputDF + * dataframe of comma-separated row strings to convert to + * dataframe of ml.linalg.Vector rows + * @return dataframe of ml.linalg.Vector rows + * @throws DMLRuntimeException + * if DMLRuntimeException occurs; per-row conversion errors surface lazily (wrapped in a SparkException) when the result is evaluated + */ + public static Dataset<Row> stringDataFrameToVectorDataFrame(SQLContext sqlContext, Dataset<Row> inputDF) + throws DMLRuntimeException { + + StructField[] oldSchema = inputDF.schema().fields(); + StructField[] newSchema = new StructField[oldSchema.length]; + for (int i = 0; i < oldSchema.length; i++) { + String colName = oldSchema[i].name(); + newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true); + } + + // converter + class StringToVector implements Function<Row, Row> { + private static final long serialVersionUID = -4733816995375745659L; + + @Override + public Row call(Row oldRow) throws Exception { + // each input row must hold at most one (string) column + int oldNumCols = oldRow.length(); + if (oldNumCols > 1) {
+ throw new DMLRuntimeException("The row must have at most one column"); + } + + // parse the various strings. i.e + // ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2) + // [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4] + // one converted value per input column (null cells stay null) + ArrayList<Object> fieldsArr = new ArrayList<Object>(); + for (int i = 0; i < oldRow.length(); i++) { + Object ci = oldRow.get(i); + if (ci == null) { + fieldsArr.add(null); + } else if (ci instanceof String) { + String cis = (String) ci; + StringBuilder sb = new StringBuilder(cis.trim()); + for (int nid = 0; nid < 2; nid++) { // strip up to two levels of + // enclosing () or [] nesting; the length check protects empty input + if (sb.length() > 0 && ((sb.charAt(0) == '(' && sb.charAt(sb.length() - 1) == ')') + || (sb.charAt(0) == '[' && sb.charAt(sb.length() - 1) == ']'))) { + sb.deleteCharAt(0); + sb.setLength(sb.length() - 1); + } + } + // drop spaces around commas and wrap in [] for NumericParser + String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]"; + + try { + // ncis [ ] will always result in double array return type + double[] doubles = (double[]) NumericParser.parse(ncis); + Vector dense = Vectors.dense(doubles); + fieldsArr.add(dense); + } catch (Exception e) { // NumericParser is Scala and declares no checked exceptions, so javac rejects catching SparkException + throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
+ } + + } else { + throw new DMLRuntimeException("Only String is supported"); + } + } + Row row = RowFactory.create(fieldsArr.toArray()); + return row; + } + } + + // output DF (a plain map keeps row order; no row index is needed) + JavaRDD<Row> newRows = inputDF.javaRDD().map(new StringToVector()); + Dataset<Row> outDF = sqlContext.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema)); + + return outDF; + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java b/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java new file mode 100644 index 0000000..7a69423 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.test.integration.conversion; + +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkException; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class RDDConverterUtilsExtTest extends AutomatedTestBase { + + private static SparkConf conf; + private static JavaSparkContext sc; + + @BeforeClass + public static void setUpClass() { + if (conf == null) + conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("RDDConverterUtilsExtTest") + .setMaster("local"); + if (sc == null) + sc = new JavaSparkContext(conf); + } + + @Override + public void setUp() { + // no setup required + } + + /** + * Convert a basic String to a spark.sql.Row.
+ */ + static class StringToRow implements Function<String, Row> { + private static final long serialVersionUID = 3945939649355731805L; + + @Override + public Row call(String str) throws Exception { + return RowFactory.create(str); + } + } + + @Test + public void testStringDataFrameToVectorDataFrame() throws DMLRuntimeException { + List<String> list = new ArrayList<String>(); + list.add("((1.2, 4.3, 3.4))"); + list.add("(1.2, 3.4, 2.2)"); + list.add("[[1.2, 34.3, 1.2, 1.25]]"); + list.add("[1.2, 3.4]"); + JavaRDD<String> javaRddString = sc.parallelize(list); + JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + Dataset<Row> inDF = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF); + + List<String> expectedResults = new ArrayList<String>(); + expectedResults.add("[[1.2,4.3,3.4]]"); + expectedResults.add("[[1.2,3.4,2.2]]"); + expectedResults.add("[[1.2,34.3,1.2,1.25]]"); + expectedResults.add("[[1.2,3.4]]"); + List<Row> outputList = outDF.collectAsList(); + assertTrue("Unexpected number of output rows: " + outputList.size(), outputList.size() == expectedResults.size()); // the contains() loop alone would pass on dropped rows + for (Row row : outputList) { + assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString())); + } + } + + @Test + public void testStringDataFrameToVectorDataFrameNull() throws DMLRuntimeException { + List<String> list = new ArrayList<String>(); + list.add("[1.2, 3.4]"); + list.add(null); + JavaRDD<String> javaRddString = sc.parallelize(list); + JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + StructType schema =
DataTypes.createStructType(fields); + Dataset<Row> inDF = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF); + + List<String> expectedResults = new ArrayList<String>(); + expectedResults.add("[[1.2,3.4]]"); + expectedResults.add("[null]"); + List<Row> outputList = outDF.collectAsList(); + assertTrue("Unexpected number of output rows: " + outputList.size(), outputList.size() == expectedResults.size()); // the contains() loop alone would pass on dropped rows + for (Row row : outputList) { + assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString())); + } + } + + @Test(expected = SparkException.class) + public void testStringDataFrameToVectorDataFrameNonNumbers() throws DMLRuntimeException { + List<String> list = new ArrayList<String>(); + list.add("[cheeseburger,fries]"); + JavaRDD<String> javaRddString = sc.parallelize(list); + JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + Dataset<Row> inDF = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF); + // trigger evaluation to throw exception + outDF.collectAsList(); + } + + @After + public void tearDown() { + super.tearDown(); + } + + @AfterClass + public static void tearDownClass() { + // stop spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + sc.stop(); + sc = null; + conf = null; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
b/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java new file mode 100644 index 0000000..b8ab13d --- /dev/null +++ b/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.conversion; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** Group together the tests in this package/related subpackages into a single suite so that the Maven build + * won't run two of them at once. */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + org.apache.sysml.test.integration.conversion.RDDConverterUtilsExtTest.class +}) + + +// This class is just a holder for the above JUnit annotations. +public class ZPackageSuite { + +}
