Repository: incubator-systemml
Updated Branches:
  refs/heads/master 03fdf0432 -> 0c85c1e52


[SYSTEMML-1277] MLContext implicitly convert mllib Vector to ml Vector

Implicitly convert DataFrame mllib.Vector columns to ml.Vector columns in
the MLContext API.

Closes #397.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0c85c1e5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0c85c1e5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0c85c1e5

Branch: refs/heads/master
Commit: 0c85c1e52e02ad740f3d4cab5a7e4bf7258061e1
Parents: 03fdf04
Author: Deron Eriksson <[email protected]>
Authored: Thu Feb 16 21:11:43 2017 -0800
Committer: Deron Eriksson <[email protected]>
Committed: Thu Feb 16 21:11:43 2017 -0800

----------------------------------------------------------------------
 .../sysml/api/mlcontext/MLContextUtil.java      |   6 +-
 .../integration/mlcontext/MLContextTest.java    | 122 ++++++++++++++++++-
 2 files changed, 124 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0c85c1e5/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index c44843e..22595c0 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -39,7 +39,7 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -505,6 +505,7 @@ public final class MLContextUtil {
                        @SuppressWarnings("unchecked")
                        Dataset<Row> dataFrame = (Dataset<Row>) value;
 
+                       dataFrame = MLUtils.convertVectorColumnsToML(dataFrame);
                        if (hasMatrixMetadata) {
                                return 
MLContextConversionUtil.dataFrameToMatrixObject(name, dataFrame, 
(MatrixMetadata) metadata);
                        } else if (hasFrameMetadata) {
@@ -598,7 +599,8 @@ public final class MLContextUtil {
                for (StructField field : fields) {
                        DataType dataType = field.dataType();
                        if ((dataType != DataTypes.DoubleType) && (dataType != 
DataTypes.IntegerType)
-                                       && (dataType != DataTypes.LongType) && 
(!(dataType instanceof VectorUDT))) {
+                                       && (dataType != DataTypes.LongType) && 
(!(dataType instanceof org.apache.spark.ml.linalg.VectorUDT))
+                                       && (!(dataType instanceof 
org.apache.spark.mllib.linalg.VectorUDT)) ) {
                                // uncomment if we support arrays of doubles 
for matrices
                                // if (dataType instanceof ArrayType) {
                                // ArrayType arrayType = (ArrayType) dataType;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0c85c1e5/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index abea5be..f6ef208 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -707,7 +707,57 @@ public class MLContextTest extends AutomatedTestBase {
 
                MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
 
-               Script script = dml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumDMLMllibVectorWithIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum DML, mllib 
vector with ID column");
+
+               List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list 
= new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(1.0, 
org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(2.0, 
org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(3.0, 
org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+               JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> 
javaRddTuple = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddTuple.map(new 
DoubleMllibVectorRow());
+               SparkSession sparkSession = 
SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+               List<StructField> fields = new ArrayList<StructField>();
+               
fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, 
DataTypes.DoubleType, true));
+               fields.add(DataTypes.createStructField("C1", new 
org.apache.spark.mllib.linalg.VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               Dataset<Row> dataFrame = 
sparkSession.createDataFrame(javaRddRow, schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
+
+               Script script = dml("print('sum: ' + sum(M));").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumPYDMLMllibVectorWithIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, mllib 
vector with ID column");
+
+               List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list 
= new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(1.0, 
org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(2.0, 
org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(3.0, 
org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+               JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> 
javaRddTuple = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddTuple.map(new 
DoubleMllibVectorRow());
+               SparkSession sparkSession = 
SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+               List<StructField> fields = new ArrayList<StructField>();
+               
fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, 
DataTypes.DoubleType, true));
+               fields.add(DataTypes.createStructField("C1", new 
org.apache.spark.mllib.linalg.VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               Dataset<Row> dataFrame = 
sparkSession.createDataFrame(javaRddRow, schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
+
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
                setExpectedStdOut("sum: 45.0");
                ml.execute(script);
        }
@@ -755,7 +805,55 @@ public class MLContextTest extends AutomatedTestBase {
 
                MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
 
-               Script script = dml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumDMLMllibVectorWithNoIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum DML, mllib 
vector with no ID column");
+
+               List<org.apache.spark.mllib.linalg.Vector> list = new 
ArrayList<org.apache.spark.mllib.linalg.Vector>();
+               list.add(org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 
3.0));
+               list.add(org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 
6.0));
+               list.add(org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 
9.0));
+               JavaRDD<org.apache.spark.mllib.linalg.Vector> javaRddVector = 
sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddVector.map(new 
MllibVectorRow());
+               SparkSession sparkSession = 
SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", new 
org.apache.spark.mllib.linalg.VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               Dataset<Row> dataFrame = 
sparkSession.createDataFrame(javaRddRow, schema);
+
+               MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
+
+               Script script = dml("print('sum: ' + sum(M));").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumPYDMLMllibVectorWithNoIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, mllib 
vector with no ID column");
+
+               List<org.apache.spark.mllib.linalg.Vector> list = new 
ArrayList<org.apache.spark.mllib.linalg.Vector>();
+               list.add(org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 
3.0));
+               list.add(org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 
6.0));
+               list.add(org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 
9.0));
+               JavaRDD<org.apache.spark.mllib.linalg.Vector> javaRddVector = 
sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddVector.map(new 
MllibVectorRow());
+               SparkSession sparkSession = 
SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", new 
org.apache.spark.mllib.linalg.VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               Dataset<Row> dataFrame = 
sparkSession.createDataFrame(javaRddRow, schema);
+
+               MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
+
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
                setExpectedStdOut("sum: 45.0");
                ml.execute(script);
        }
@@ -771,6 +869,17 @@ public class MLContextTest extends AutomatedTestBase {
                }
        }
 
+       static class DoubleMllibVectorRow implements Function<Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>, Row> {
+               private static final long serialVersionUID = 
-3121178154451876165L;
+
+               @Override
+               public Row call(Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector> tup) throws Exception {
+                       Double doub = tup._1();
+                       org.apache.spark.mllib.linalg.Vector vect = tup._2();
+                       return RowFactory.create(doub, vect);
+               }
+       }
+
        static class VectorRow implements Function<Vector, Row> {
                private static final long serialVersionUID = 
7077761802433569068L;
 
@@ -780,6 +889,15 @@ public class MLContextTest extends AutomatedTestBase {
                }
        }
 
+       static class MllibVectorRow implements 
Function<org.apache.spark.mllib.linalg.Vector, Row> {
+               private static final long serialVersionUID = 
-408929813562996706L;
+
+               @Override
+               public Row call(org.apache.spark.mllib.linalg.Vector vect) 
throws Exception {
+                       return RowFactory.create(vect);
+               }
+       }
+
        static class CommaSeparatedValueStringToRow implements Function<String, 
Row> {
                private static final long serialVersionUID = 
-7871020122671747808L;
 

Reply via email to