Repository: incubator-systemml
Updated Branches:
  refs/heads/master 03fdf0432 -> 0c85c1e52
[SYSTEMML-1277] MLContext implicitly convert mllib Vector to ml Vector

Implicitly convert DataFrame mllib.Vector columns to ml.Vector columns
in the MLContext API.

Closes #397.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0c85c1e5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0c85c1e5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0c85c1e5

Branch: refs/heads/master
Commit: 0c85c1e52e02ad740f3d4cab5a7e4bf7258061e1
Parents: 03fdf04
Author: Deron Eriksson <[email protected]>
Authored: Thu Feb 16 21:11:43 2017 -0800
Committer: Deron Eriksson <[email protected]>
Committed: Thu Feb 16 21:11:43 2017 -0800

----------------------------------------------------------------------
 .../sysml/api/mlcontext/MLContextUtil.java      |   6 +-
 .../integration/mlcontext/MLContextTest.java    | 122 ++++++++++++++++++-
 2 files changed, 124 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0c85c1e5/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index c44843e..22595c0 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -39,7 +39,7 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -505,6 +505,7 @@ public final class MLContextUtil {
 			@SuppressWarnings("unchecked")
 			Dataset<Row> dataFrame = (Dataset<Row>) value;
+			dataFrame = MLUtils.convertVectorColumnsToML(dataFrame);
 			if (hasMatrixMetadata) {
 				return MLContextConversionUtil.dataFrameToMatrixObject(name, dataFrame, (MatrixMetadata) metadata);
 			} else if (hasFrameMetadata) {
@@ -598,7 +599,8 @@ public final class MLContextUtil {
 		for (StructField field : fields) {
 			DataType dataType = field.dataType();
 			if ((dataType != DataTypes.DoubleType) && (dataType != DataTypes.IntegerType)
-					&& (dataType != DataTypes.LongType) && (!(dataType instanceof VectorUDT))) {
+					&& (dataType != DataTypes.LongType) && (!(dataType instanceof org.apache.spark.ml.linalg.VectorUDT))
+					&& (!(dataType instanceof org.apache.spark.mllib.linalg.VectorUDT))) {
 				// uncomment if we support arrays of doubles for matrices
 				// if (dataType instanceof ArrayType) {
 				// ArrayType arrayType = (ArrayType) dataType;
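The single added line above, dataFrame = MLUtils.convertVectorColumnsToML(dataFrame),
is what performs the implicit conversion: every incoming DataFrame is passed through
Spark's converter before matrix or frame conversion. As a minimal standalone sketch of
that Spark utility (the class and variable names below are illustrative, not part of
this commit):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class ConvertVectorColumnsSketch {
	public static void main(String[] args) {
		SparkSession spark = SparkSession.builder()
				.appName("ConvertVectorColumnsSketch").master("local").getOrCreate();

		// Build a DataFrame whose "features" column uses the legacy mllib VectorUDT.
		List<Row> rows = Arrays.asList(
				RowFactory.create(Vectors.dense(1.0, 2.0, 3.0)),
				RowFactory.create(Vectors.dense(4.0, 5.0, 6.0)));
		StructType schema = DataTypes.createStructType(new StructField[] {
				DataTypes.createStructField("features", new org.apache.spark.mllib.linalg.VectorUDT(), true) });
		Dataset<Row> df = spark.createDataFrame(rows, schema);

		// Convert all mllib.linalg vector columns to their ml.linalg equivalents;
		// this is the call MLContextUtil now applies to every incoming DataFrame.
		Dataset<Row> converted = MLUtils.convertVectorColumnsToML(df);
		converted.printSchema(); // "features" is now an ml.linalg vector column

		spark.stop();
	}
}

Calling the converter with no column names converts all mllib.linalg vector columns,
which is why MLContextUtil can apply it unconditionally.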
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0c85c1e5/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index abea5be..f6ef208 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -707,7 +707,57 @@ public class MLContextTest extends AutomatedTestBase {
 
 		MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
 
-		Script script = dml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+		Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+		setExpectedStdOut("sum: 45.0");
+		ml.execute(script);
+	}
+
+	@Test
+	public void testDataFrameSumDMLMllibVectorWithIDColumn() {
+		System.out.println("MLContextTest - DataFrame sum DML, mllib vector with ID column");
+
+		List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+		JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> javaRddTuple = sc.parallelize(list);
+
+		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleMllibVectorRow());
+		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+		List<StructField> fields = new ArrayList<StructField>();
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
+		StructType schema = DataTypes.createStructType(fields);
+		Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema);
+
+		MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
+
+		Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
+		setExpectedStdOut("sum: 45.0");
+		ml.execute(script);
+	}
+
+	@Test
+	public void testDataFrameSumPYDMLMllibVectorWithIDColumn() {
+		System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with ID column");
+
+		List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+		JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> javaRddTuple = sc.parallelize(list);
+
+		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleMllibVectorRow());
+		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+		List<StructField> fields = new ArrayList<StructField>();
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
+		StructType schema = DataTypes.createStructType(fields);
+		Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema);
+
+		MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
+
+		Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
 		setExpectedStdOut("sum: 45.0");
 		ml.execute(script);
 	}
@@ -755,7 +805,55 @@ public class MLContextTest extends AutomatedTestBase {
 
 		MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
 
-		Script script = dml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+		Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+		setExpectedStdOut("sum: 45.0");
+		ml.execute(script);
+	}
+
+	@Test
+	public void testDataFrameSumDMLMllibVectorWithNoIDColumn() {
+		System.out.println("MLContextTest - DataFrame sum DML, mllib vector with no ID column");
+
+		List<org.apache.spark.mllib.linalg.Vector> list = new ArrayList<org.apache.spark.mllib.linalg.Vector>();
+		list.add(org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0));
+		list.add(org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0));
+		list.add(org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0));
+		JavaRDD<org.apache.spark.mllib.linalg.Vector> javaRddVector = sc.parallelize(list);
+
+		JavaRDD<Row> javaRddRow = javaRddVector.map(new MllibVectorRow());
+		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+		List<StructField> fields = new ArrayList<StructField>();
+		fields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
+		StructType schema = DataTypes.createStructType(fields);
+		Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema);
+
+		MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
+
+		Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
+		setExpectedStdOut("sum: 45.0");
+		ml.execute(script);
+	}
+
+	@Test
+	public void testDataFrameSumPYDMLMllibVectorWithNoIDColumn() {
+		System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with no ID column");
+
+		List<org.apache.spark.mllib.linalg.Vector> list = new ArrayList<org.apache.spark.mllib.linalg.Vector>();
+		list.add(org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0));
+		list.add(org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0));
+		list.add(org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0));
+		JavaRDD<org.apache.spark.mllib.linalg.Vector> javaRddVector = sc.parallelize(list);
+
+		JavaRDD<Row> javaRddRow = javaRddVector.map(new MllibVectorRow());
+		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+		List<StructField> fields = new ArrayList<StructField>();
+		fields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
+		StructType schema = DataTypes.createStructType(fields);
+		Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema);
+
+		MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
+
+		Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
 		setExpectedStdOut("sum: 45.0");
 		ml.execute(script);
 	}
@@ -771,6 +869,17 @@ public class MLContextTest extends AutomatedTestBase {
 		}
 	}
 
+	static class DoubleMllibVectorRow implements Function<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>, Row> {
+		private static final long serialVersionUID = -3121178154451876165L;
+
+		@Override
+		public Row call(Tuple2<Double, org.apache.spark.mllib.linalg.Vector> tup) throws Exception {
+			Double doub = tup._1();
+			org.apache.spark.mllib.linalg.Vector vect = tup._2();
+			return RowFactory.create(doub, vect);
+		}
+	}
+
 	static class VectorRow implements Function<Vector, Row> {
 		private static final long serialVersionUID = 7077761802433569068L;
 
@@ -780,6 +889,15 @@ public class MLContextTest extends AutomatedTestBase {
 		}
 	}
 
+	static class MllibVectorRow implements Function<org.apache.spark.mllib.linalg.Vector, Row> {
+		private static final long serialVersionUID = -408929813562996706L;
+
+		@Override
+		public Row call(org.apache.spark.mllib.linalg.Vector vect) throws Exception {
+			return RowFactory.create(vect);
+		}
+	}
+
 	static class CommaSeparatedValueStringToRow implements Function<String, Row> {
 		private static final long serialVersionUID = -7871020122671747808L;
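With this change in place, a caller of the MLContext API can bind a DataFrame
containing legacy mllib Vector columns directly, without converting it first. A
minimal usage sketch (the class and method names below are hypothetical; sc and df
are assumed to be a JavaSparkContext and a DataFrame with an
org.apache.spark.mllib.linalg.VectorUDT column, as constructed in the tests above):

import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.Script;

public class ImplicitMllibVectorExample {
	// 'sc' and 'df' are assumed to be created elsewhere, e.g. as in the
	// tests above (df has an mllib VectorUDT column "C1").
	public static void sumVectors(JavaSparkContext sc, Dataset<Row> df) {
		MLContext ml = new MLContext(sc);
		// No manual MLUtils.convertVectorColumnsToML call is needed here;
		// MLContextUtil now applies it when the DataFrame is bound to M.
		Script script = dml("print('sum: ' + sum(M))").in("M", df);
		ml.execute(script);
	}
}

Internally this relies on the MLUtils.convertVectorColumnsToML call shown in the
MLContextUtil diff, so mllib and ml vector inputs end up on the same conversion path.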
