Repository: incubator-systemml
Updated Branches:
  refs/heads/master 9f12b5c66 -> 97dee8fba


[SYSTEMML-834] Improve MLContext DataFrame support

Closes #218.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/97dee8fb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/97dee8fb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/97dee8fb

Branch: refs/heads/master
Commit: 97dee8fba7252f9b868e23d69f23c36053f48445
Parents: 9f12b5c
Author: Deron Eriksson <[email protected]>
Authored: Thu Aug 25 14:33:00 2016 -0700
Committer: Deron Eriksson <[email protected]>
Committed: Thu Aug 25 14:33:00 2016 -0700

----------------------------------------------------------------------
 .../sysml/api/mlcontext/BinaryBlockMatrix.java  |  21 +-
 .../api/mlcontext/MLContextConversionUtil.java  |  66 +++-
 .../apache/sysml/api/mlcontext/MLResults.java   | 136 ++++++-
 .../org/apache/sysml/api/mlcontext/Matrix.java  |  46 ++-
 .../sysml/api/mlcontext/MatrixFormat.java       |  54 ++-
 .../integration/mlcontext/MLContextTest.java    | 390 ++++++++++++++++++-
 6 files changed, 689 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java 
b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
index ea6fcf0..b13669d 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
@@ -99,12 +99,21 @@ public class BinaryBlockMatrix {
        public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() {
                return binaryBlocks;
        }
-       
-       public MatrixBlock getMatrixBlock() throws DMLRuntimeException {
-               MatrixCharacteristics mc = getMatrixCharacteristics();
-               MatrixBlock mb = 
SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) 
mc.getCols(), 
-                               mc.getRowsPerBlock(), mc.getColsPerBlock(), 
mc.getNonZeros());
-               return mb;
+
+       /**
+        * Obtain a SystemML binary-block matrix as a {@code MatrixBlock}.
+        * @return the SystemML binary-block matrix as a {@code MatrixBlock}
+        * @throws MLContextException if the conversion raises a {@code DMLRuntimeException}
+        */
+       public MatrixBlock getMatrixBlock() {
+               try {
+                       MatrixCharacteristics mc = getMatrixCharacteristics();
+                       MatrixBlock mb = 
SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) 
mc.getCols(),
+                                       mc.getRowsPerBlock(), 
mc.getColsPerBlock(), mc.getNonZeros());
+                       return mb;
+               } catch (DMLRuntimeException e) {
+                       throw new MLContextException("Exception while getting 
MatrixBlock from binary-block matrix", e);
+               }
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 33a5a3c..3a482ef 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -33,6 +33,7 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
@@ -311,7 +312,15 @@ public class MLContextConversionUtil {
                } else {
                        matrixCharacteristics = new MatrixCharacteristics();
                }
-               determineDataFrameDimensionsIfNeeded(dataFrame, 
matrixCharacteristics);
+
+               if (isDataFrameWithIDColumn(matrixMetadata)) {
+                       // sort numerically: a lexical sort of string IDs would put "10" before "2"
+                       dataFrame = dataFrame.sort(dataFrame.col("ID").cast("double")).drop("ID");
+               }
+
+               boolean isVectorBasedDataFrame = 
isVectorBasedDataFrame(matrixMetadata);
+
+               determineDataFrameDimensionsIfNeeded(dataFrame, 
matrixCharacteristics, isVectorBasedDataFrame);
                if (matrixMetadata != null) {
                        // so external reference can be updated with the 
metadata
                        
matrixMetadata.setMatrixCharacteristics(matrixCharacteristics);
@@ -320,12 +328,50 @@ public class MLContextConversionUtil {
                JavaRDD<Row> javaRDD = dataFrame.javaRDD();
                JavaPairRDD<Row, Long> prepinput = javaRDD.zipWithIndex();
                JavaPairRDD<MatrixIndexes, MatrixBlock> out = 
prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction(
-                               matrixCharacteristics, false));
+                               matrixCharacteristics, isVectorBasedDataFrame));
                out = RDDAggregateUtils.mergeByKey(out);
                return out;
        }
 
        /**
+        * Return whether the matrix metadata declares a DataFrame format with an ID column.
+        * 
+        * @param matrixMetadata
+        *            the matrix metadata
+        * @return {@code true} if the matrix format has an ID column; {@code false}
+        *         otherwise, including when the metadata or its format is {@code null}.
+        */
+       public static boolean isDataFrameWithIDColumn(MatrixMetadata 
matrixMetadata) {
+               if (matrixMetadata == null) {
+                       return false;
+               }
+               MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat();
+               if (matrixFormat == null) {
+                       return false;
+               }
+               return matrixFormat.hasIDColumn();
+       }
+
+       /**
+        * Return whether the matrix metadata declares a vector-based DataFrame format.
+        * 
+        * @param matrixMetadata
+        *            the matrix metadata
+        * @return {@code true} if the matrix format is vector-based; {@code false}
+        *         otherwise, including when the metadata or its format is {@code null}.
+        */
+       public static boolean isVectorBasedDataFrame(MatrixMetadata 
matrixMetadata) {
+               if (matrixMetadata == null) {
+                       return false;
+               }
+               MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat();
+               if (matrixFormat == null) {
+                       return false;
+               }
+               return matrixFormat.isVectorBased();
+       }
+
+       /**
         * If the {@code DataFrame} dimensions aren't present in the
         * {@code MatrixCharacteristics} metadata, determine the dimensions and
         * place them in the {@code MatrixCharacteristics} metadata.
@@ -334,20 +380,30 @@ public class MLContextConversionUtil {
         *            the Spark {@code DataFrame}
         * @param matrixCharacteristics
         *            the matrix metadata
+        * @param vectorBased
+        *            {@code true} if the first column of the DataFrame holds {@code Vector}s
         */
        public static void determineDataFrameDimensionsIfNeeded(DataFrame 
dataFrame,
-                       MatrixCharacteristics matrixCharacteristics) {
+                       MatrixCharacteristics matrixCharacteristics, boolean 
vectorBased) {
                if (!matrixCharacteristics.dimsKnown(true)) {
-                       // only available to the new MLContext API, not the old 
API
                        MLContext activeMLContext = (MLContext) 
MLContextProxy.getActiveMLContext();
                        SparkContext sparkContext = 
activeMLContext.getSparkContext();
                        @SuppressWarnings("resource")
                        JavaSparkContext javaSparkContext = new 
JavaSparkContext(sparkContext);
 
                        Accumulator<Double> aNnz = 
javaSparkContext.accumulator(0L);
-                       JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new 
DataFrameAnalysisFunction(aNnz, false));
+                       JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new 
DataFrameAnalysisFunction(aNnz, vectorBased));
                        long numRows = javaRDD.count();
-                       long numColumns = dataFrame.columns().length;
+                       long numColumns;
+                       if (vectorBased) {
+                               // Use the source DataFrame: javaRDD.first() would re-run the map that
+                               // feeds aNnz (inflating the nonzero count) and throws on empty input.
+                               Vector v = (numRows > 0) ? (Vector) dataFrame.first().get(0) : null;
+                               numColumns = (v == null) ? 0 : v.size();
+                       } else {
+                               numColumns = dataFrame.columns().length;
+                       }
+
                        long numNonZeros = UtilFunctions.toLong(aNnz.value());
                        matrixCharacteristics.set(numRows, numColumns, 
matrixCharacteristics.getRowsPerBlock(),
                        
matrixCharacteristics.getColsPerBlock(), numNonZeros);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
index 31798e0..dbc8f5d 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
@@ -236,7 +236,7 @@ public class MLResults {
        }
 
        /**
-        * Obtain an output as a {@code DataFrame} of doubles.
+        * Obtain an output as a {@code DataFrame} of doubles with an ID column.
         * <p>
         * The following matrix in DML:
         * </p>
@@ -245,13 +245,13 @@ public class MLResults {
         * <p>
         * is equivalent to the following {@code DataFrame} of doubles:
         * </p>
-        * <code>[0.0,1.0,2.0]
-        * <br>[1.0,3.0,4.0]
+        * <code>[1.0,1.0,2.0]
+        * <br>[2.0,3.0,4.0]
         * </code>
         *
         * @param outputName
         *            the name of the output
-        * @return the output as a {@code DataFrame} of doubles
+        * @return the output as a {@code DataFrame} of doubles with an ID 
column
         */
        public DataFrame getDataFrame(String outputName) {
                MatrixObject mo = getMatrixObject(outputName);
@@ -259,6 +259,35 @@ public class MLResults {
                return df;
        }
 
+       /**
+        * Obtain an output as a {@code DataFrame} of doubles or vectors with 
an ID
+        * column.
+        * <p>
+        * The following matrix in DML:
+        * </p>
+        * <code>M = full('1 2 3 4', rows=2, cols=2);
+        * </code>
+        * <p>
+        * is equivalent to the following {@code DataFrame} of doubles:
+        * </p>
+        * <code>[1.0,1.0,2.0]
+        * <br>[2.0,3.0,4.0]
+        * </code>
+        * <p>
+        * or the following {@code DataFrame} of vectors:
+        * </p>
+        * <code>[1.0,[1.0,2.0]]
+        * <br>[2.0,[3.0,4.0]]
+        * </code>
+        *
+        * @param outputName
+        *            the name of the output
+        * @param isVectorDF
+        *            {@code true} for a vector {@code DataFrame}, {@code 
false} for
+        *            a double {@code DataFrame}
+        * @return the output as a {@code DataFrame} of doubles or vectors with 
an
+        *         ID column
+        */
        public DataFrame getDataFrame(String outputName, boolean isVectorDF) {
                MatrixObject mo = getMatrixObject(outputName);
                DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, 
isVectorDF);
@@ -266,6 +295,104 @@ public class MLResults {
        }
 
        /**
+        * Obtain an output as a {@code DataFrame} of doubles with an ID column.
+        * <p>
+        * The following matrix in DML:
+        * </p>
+        * <code>M = full('1 2 3 4', rows=2, cols=2);
+        * </code>
+        * <p>
+        * is equivalent to the following {@code DataFrame} of doubles:
+        * </p>
+        * <code>[1.0,1.0,2.0]
+        * <br>[2.0,3.0,4.0]
+        * </code>
+        *
+        * @param outputName
+        *            the name of the output
+        * @return the output as a {@code DataFrame} of doubles with an ID 
column
+        */
+       public DataFrame getDataFrameDoubleWithIDColumn(String outputName) {
+               MatrixObject mo = getMatrixObject(outputName);
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, 
false);
+               return df;
+       }
+
+       /**
+        * Obtain an output as a {@code DataFrame} of vectors with an ID column.
+        * <p>
+        * The following matrix in DML:
+        * </p>
+        * <code>M = full('1 2 3 4', rows=2, cols=2);
+        * </code>
+        * <p>
+        * is equivalent to the following {@code DataFrame} of vectors:
+        * </p>
+        * <code>[1.0,[1.0,2.0]]
+        * <br>[2.0,[3.0,4.0]]
+        * </code>
+        *
+        * @param outputName
+        *            the name of the output
+        * @return the output as a {@code DataFrame} of vectors with an ID 
column
+        */
+       public DataFrame getDataFrameVectorWithIDColumn(String outputName) {
+               MatrixObject mo = getMatrixObject(outputName);
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, 
true);
+               return df;
+       }
+
+       /**
+        * Obtain an output as a {@code DataFrame} of doubles with no ID column.
+        * <p>
+        * The following matrix in DML:
+        * </p>
+        * <code>M = full('1 2 3 4', rows=2, cols=2);
+        * </code>
+        * <p>
+        * is equivalent to the following {@code DataFrame} of doubles:
+        * </p>
+        * <code>[1.0,2.0]
+        * <br>[3.0,4.0]
+        * </code>
+        *
+        * @param outputName
+        *            the name of the output
+        * @return the output as a {@code DataFrame} of doubles with no ID column, rows ordered by ID
+        */
+       public DataFrame getDataFrameDoubleNoIDColumn(String outputName) {
+               MatrixObject mo = getMatrixObject(outputName);
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, 
false);
+               df = df.sort("ID").drop("ID");
+               return df;
+       }
+
+       /**
+        * Obtain an output as a {@code DataFrame} of vectors with no ID column.
+        * <p>
+        * The following matrix in DML:
+        * </p>
+        * <code>M = full('1 2 3 4', rows=2, cols=2);
+        * </code>
+        * <p>
+        * is equivalent to the following {@code DataFrame} of vectors:
+        * </p>
+        * <code>[[1.0,2.0]]
+        * <br>[[3.0,4.0]]
+        * </code>
+        *
+        * @param outputName
+        *            the name of the output
+        * @return the output as a {@code DataFrame} of vectors with no ID column, rows ordered by ID
+        */
+       public DataFrame getDataFrameVectorNoIDColumn(String outputName) {
+               MatrixObject mo = getMatrixObject(outputName);
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, 
true);
+               df = df.sort("ID").drop("ID");
+               return df;
+       }
+
+       /**
         * Obtain an output as a {@code Matrix}.
         *
         * @param outputName
@@ -278,7 +405,6 @@ public class MLResults {
                return matrix;
        }
 
-
        /**
         * Obtain an output as a {@code BinaryBlockMatrix}.
         *

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java 
b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
index 3ee41b7..abd785c 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
@@ -103,9 +103,9 @@ public class Matrix {
        }
 
        /**
-        * Obtain the matrix as a {@code DataFrame}
+        * Obtain the matrix as a {@code DataFrame} of doubles with an ID column
         * 
-        * @return the matrix as a {@code DataFrame}
+        * @return the matrix as a {@code DataFrame} of doubles with an ID 
column
         */
        public DataFrame asDataFrame() {
                DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, 
sparkExecutionContext, false);
@@ -113,6 +113,48 @@ public class Matrix {
        }
 
        /**
+        * Obtain the matrix as a {@code DataFrame} of doubles with an ID column
+        * 
+        * @return the matrix as a {@code DataFrame} of doubles with an ID 
column
+        */
+       public DataFrame asDataFrameDoubleWithIDColumn() {
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, 
sparkExecutionContext, false);
+               return df;
+       }
+
+       /**
+        * Obtain the matrix as a {@code DataFrame} of doubles with no ID column.
+        * 
+        * @return the matrix as a {@code DataFrame} of doubles with no ID column, rows ordered by ID
+        */
+       public DataFrame asDataFrameDoubleNoIDColumn() {
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, 
sparkExecutionContext, false);
+               df = df.sort("ID").drop("ID");
+               return df;
+       }
+
+       /**
+        * Obtain the matrix as a {@code DataFrame} of vectors with an ID column
+        * 
+        * @return the matrix as a {@code DataFrame} of vectors with an ID 
column
+        */
+       public DataFrame asDataFrameVectorWithIDColumn() {
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, 
sparkExecutionContext, true);
+               return df;
+       }
+
+       /**
+        * Obtain the matrix as a {@code DataFrame} of vectors with no ID column.
+        * 
+        * @return the matrix as a {@code DataFrame} of vectors with no ID column, rows ordered by ID
+        */
+       public DataFrame asDataFrameVectorNoIDColumn() {
+               DataFrame df = 
MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, 
sparkExecutionContext, true);
+               df = df.sort("ID").drop("ID");
+               return df;
+       }
+
+       /**
         * Obtain the matrix as a {@code BinaryBlockMatrix}
         * 
         * @return the matrix as a {@code BinaryBlockMatrix}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
index 50ed634..a7ac395 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
@@ -34,6 +34,46 @@ public enum MatrixFormat {
         * (I J V) format (sparse). I and J represent matrix coordinates and V
         * represents the value. The I J and V values are space-separated.
         */
-       IJV;
+       IJV,
+
+       /**
+        * DataFrame of doubles with an ID column.
+        */
+       DF_DOUBLES_WITH_ID_COLUMN,
+
+       /**
+        * DataFrame of doubles with no ID column.
+        */
+       DF_DOUBLES_WITH_NO_ID_COLUMN,
+
+       /**
+        * Vector DataFrame with an ID column.
+        */
+       DF_VECTOR_WITH_ID_COLUMN,
+
+       /**
+        * Vector DataFrame with no ID column.
+        */
+       DF_VECTOR_WITH_NO_ID_COLUMN;
+
+       /**
+        * Whether this format describes a DataFrame whose data is stored as vectors.
+        * 
+        * @return {@code true} for the two vector DataFrame formats, {@code false}
+        *         for every other format.
+        */
+       public boolean isVectorBased() {
+               return this == DF_VECTOR_WITH_ID_COLUMN || this == DF_VECTOR_WITH_NO_ID_COLUMN;
+       }
+
+       /**
+        * Whether this format describes a DataFrame that carries an ID column.
+        * 
+        * @return {@code true} for the two DataFrame formats with an ID column,
+        *         {@code false} for every other format.
+        */
+       public boolean hasIDColumn() {
+               return this == DF_DOUBLES_WITH_ID_COLUMN || this == DF_VECTOR_WITH_ID_COLUMN;
+       }
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index e6e1046..7be657b 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -46,6 +46,9 @@ import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
@@ -497,8 +500,8 @@ public class MLContextTest extends AutomatedTestBase {
        }
 
        @Test
-       public void testDataFrameSumDML() {
-               System.out.println("MLContextTest - DataFrame sum DML");
+       public void testDataFrameSumDMLDoublesWithNoIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum DML, doubles 
with no ID column");
 
                List<String> list = new ArrayList<String>();
                list.add("10,20,30");
@@ -521,8 +524,8 @@ public class MLContextTest extends AutomatedTestBase {
        }
 
        @Test
-       public void testDataFrameSumPYDML() {
-               System.out.println("MLContextTest - DataFrame sum PYDML");
+       public void testDataFrameSumPYDMLDoublesWithNoIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, 
doubles with no ID column");
 
                List<String> list = new ArrayList<String>();
                list.add("10,20,30");
@@ -544,9 +547,236 @@ public class MLContextTest extends AutomatedTestBase {
                ml.execute(script);
        }
 
+       @Test
+       public void testDataFrameSumDMLDoublesWithIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum DML, doubles 
with ID column");
+
+               List<String> list = new ArrayList<String>();
+               list.add("1,1,2,3");
+               list.add("2,4,5,6");
+               list.add("3,7,8,9");
+               JavaRDD<String> javaRddString = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddString.map(new 
CommaSeparatedValueStringToRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C2", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C3", 
DataTypes.StringType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+               Script script = dml("print('sum: ' + sum(M));").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumPYDMLDoublesWithIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, 
doubles with ID column");
+
+               List<String> list = new ArrayList<String>();
+               list.add("1,1,2,3");
+               list.add("2,4,5,6");
+               list.add("3,7,8,9");
+               JavaRDD<String> javaRddString = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddString.map(new 
CommaSeparatedValueStringToRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C2", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C3", 
DataTypes.StringType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumDMLDoublesWithIDColumnSortCheck() {
+               System.out.println("MLContextTest - DataFrame sum DML, doubles 
with ID column sort check");
+
+               List<String> list = new ArrayList<String>();
+               list.add("3,7,8,9");
+               list.add("1,1,2,3");
+               list.add("2,4,5,6");
+               JavaRDD<String> javaRddString = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddString.map(new 
CommaSeparatedValueStringToRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C2", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C3", 
DataTypes.StringType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+               Script script = dml("print('M[1,1]: ' + 
as.scalar(M[1,1]));").in("M", dataFrame, mm);
+               setExpectedStdOut("M[1,1]: 1.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumPYDMLDoublesWithIDColumnSortCheck() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, doubles with ID column sort check");
+
+               List<String> list = new ArrayList<String>();
+               list.add("3,7,8,9");
+               list.add("1,1,2,3");
+               list.add("2,4,5,6");
+               JavaRDD<String> javaRddString = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddString.map(new 
CommaSeparatedValueStringToRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C2", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C3", 
DataTypes.StringType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+               Script script = pydml("print('M[0,0]: ' + 
scalar(M[0,0]))").in("M", dataFrame, mm);
+               setExpectedStdOut("M[0,0]: 1.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumDMLVectorWithIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum DML, vector 
with ID column");
+
+               List<Tuple2<Double, Vector>> list = new 
ArrayList<Tuple2<Double, Vector>>();
+               list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 
2.0, 3.0)));
+               list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 
5.0, 6.0)));
+               list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 
8.0, 9.0)));
+               JavaRDD<Tuple2<Double, Vector>> javaRddTuple = 
sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddTuple.map(new 
DoubleVectorRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", new VectorUDT(), 
true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_ID_COLUMN);
+
+               Script script = dml("print('sum: ' + sum(M));").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       @Test
+       public void testDataFrameSumPYDMLVectorWithIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, vector 
with ID column");
+
+               List<Tuple2<Double, Vector>> list = new 
ArrayList<Tuple2<Double, Vector>>();
+               list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 
2.0, 3.0)));
+               list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 
5.0, 6.0)));
+               list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 
8.0, 9.0)));
+               JavaRDD<Tuple2<Double, Vector>> javaRddTuple = 
sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddTuple.map(new 
DoubleVectorRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", 
DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", new VectorUDT(), 
true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
+
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_ID_COLUMN);
+
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       /** Sums a single Vector-column DataFrame (no ID column) passed into a DML script. */
+       @Test
+       public void testDataFrameSumDMLVectorWithNoIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column");
+
+               List<Vector> vectors = new ArrayList<Vector>();
+               vectors.add(Vectors.dense(1.0, 2.0, 3.0));
+               vectors.add(Vectors.dense(4.0, 5.0, 6.0));
+               vectors.add(Vectors.dense(7.0, 8.0, 9.0));
+               JavaRDD<Row> rowRdd = sc.parallelize(vectors).map(new VectorRow());
+
+               // Schema: one nullable vector column named C1.
+               List<StructField> columns = new ArrayList<StructField>();
+               columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+               StructType vectorSchema = DataTypes.createStructType(columns);
+               SQLContext sqlCtx = new SQLContext(sc);
+               DataFrame df = sqlCtx.createDataFrame(rowRdd, vectorSchema);
+
+               MatrixMetadata metadata = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN);
+               Script sumScript = dml("print('sum: ' + sum(M));").in("M", df, metadata);
+
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(sumScript);
+       }
+
+       /** Sums a single Vector-column DataFrame (no ID column) passed into a PYDML script. */
+       @Test
+       public void testDataFrameSumPYDMLVectorWithNoIDColumn() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, vector with no ID column");
+
+               List<Vector> list = new ArrayList<Vector>();
+               list.add(Vectors.dense(1.0, 2.0, 3.0));
+               list.add(Vectors.dense(4.0, 5.0, 6.0));
+               list.add(Vectors.dense(7.0, 8.0, 9.0));
+               JavaRDD<Vector> javaRddVector = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+               MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN);
+
+               // PYDML variant: build the script with pydml(), not dml(), otherwise this
+               // test silently duplicates the DML test instead of exercising PYDML.
+               Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       /**
+        * Spark mapping function that turns a (Double, Vector) tuple into a
+        * two-field Row: the tuple's Double first, then its Vector.
+        */
+       static class DoubleVectorRow implements Function<Tuple2<Double, Vector>, Row> {
+               private static final long serialVersionUID = 3605080559931384163L;
+
+               @Override
+               public Row call(Tuple2<Double, Vector> tup) throws Exception {
+                       return RowFactory.create(tup._1(), tup._2());
+               }
+       }
+
+       static class VectorRow implements Function<Vector, Row> {
+               private static final long serialVersionUID = 
7077761802433569068L;
+
+               @Override
+               public Row call(Vector vect) throws Exception {
+                       return RowFactory.create(vect);
+               }
+       }
+
        static class CommaSeparatedValueStringToRow implements Function<String, 
Row> {
                private static final long serialVersionUID = 
-7871020122671747808L;
 
+               @Override
                public Row call(String str) throws Exception {
                        String[] fields = str.split(",");
                        return RowFactory.create((Object[]) fields);
@@ -1032,6 +1262,158 @@ public class MLContextTest extends AutomatedTestBase {
        }
 
        @Test
+       public void testOutputDataFrameDMLVectorWithIDColumn() {
+               System.out.println("MLContextTest - output DataFrame DML, vector with ID column");
+
+               // 2x2 matrix [1 2; 3 4] read back as (ID, row-vector) rows.
+               MLResults results = ml.execute(dml("M = matrix('1 2 3 4', rows=2, cols=2);").out("M"));
+               List<Row> rows = results.getDataFrameVectorWithIDColumn("M").collectAsList();
+
+               Row first = rows.get(0);
+               Assert.assertEquals(1.0, first.getDouble(0), 0.0);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) first.get(1)).toArray(), 0.0);
+
+               Row second = rows.get(1);
+               Assert.assertEquals(2.0, second.getDouble(0), 0.0);
+               Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) second.get(1)).toArray(), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFramePYDMLVectorWithIDColumn() {
+               System.out.println("MLContextTest - output DataFrame PYDML, vector with ID column");
+
+               // 2x2 matrix [1 2; 3 4] built in PYDML, read back as (ID, row-vector) rows.
+               Script pyScript = pydml("M = full('1 2 3 4', rows=2, cols=2)").out("M");
+               List<Row> rows = ml.execute(pyScript).getDataFrameVectorWithIDColumn("M").collectAsList();
+
+               Row first = rows.get(0);
+               Assert.assertEquals(1.0, first.getDouble(0), 0.0);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) first.get(1)).toArray(), 0.0);
+
+               Row second = rows.get(1);
+               Assert.assertEquals(2.0, second.getDouble(0), 0.0);
+               Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) second.get(1)).toArray(), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameDMLVectorNoIDColumn() {
+               System.out.println("MLContextTest - output DataFrame DML, vector no ID column");
+
+               // 2x2 matrix [1 2; 3 4] read back as single-column row-vector rows.
+               MLResults results = ml.execute(dml("M = matrix('1 2 3 4', rows=2, cols=2);").out("M"));
+               List<Row> rows = results.getDataFrameVectorNoIDColumn("M").collectAsList();
+
+               Vector firstVector = (Vector) rows.get(0).get(0);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, firstVector.toArray(), 0.0);
+
+               Vector secondVector = (Vector) rows.get(1).get(0);
+               Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, secondVector.toArray(), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFramePYDMLVectorNoIDColumn() {
+               System.out.println("MLContextTest - output DataFrame PYDML, vector no ID column");
+
+               // 2x2 matrix [1 2; 3 4] built in PYDML, read back as row-vector rows.
+               Script pyScript = pydml("M = full('1 2 3 4', rows=2, cols=2)").out("M");
+               List<Row> rows = ml.execute(pyScript).getDataFrameVectorNoIDColumn("M").collectAsList();
+
+               Vector firstVector = (Vector) rows.get(0).get(0);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, firstVector.toArray(), 0.0);
+
+               Vector secondVector = (Vector) rows.get(1).get(0);
+               Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, secondVector.toArray(), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameDMLDoublesWithIDColumn() {
+               System.out.println("MLContextTest - output DataFrame DML, doubles with ID column");
+
+               // 2x2 matrix [1 2; 3 4] read back as (ID, c1, c2) double rows.
+               MLResults results = ml.execute(dml("M = matrix('1 2 3 4', rows=2, cols=2);").out("M"));
+               List<Row> rows = results.getDataFrameDoubleWithIDColumn("M").collectAsList();
+
+               Row first = rows.get(0);
+               Assert.assertEquals(1.0, first.getDouble(0), 0.0);
+               Assert.assertEquals(1.0, first.getDouble(1), 0.0);
+               Assert.assertEquals(2.0, first.getDouble(2), 0.0);
+
+               Row second = rows.get(1);
+               Assert.assertEquals(2.0, second.getDouble(0), 0.0);
+               Assert.assertEquals(3.0, second.getDouble(1), 0.0);
+               Assert.assertEquals(4.0, second.getDouble(2), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFramePYDMLDoublesWithIDColumn() {
+               System.out.println("MLContextTest - output DataFrame PYDML, doubles with ID column");
+
+               // 2x2 matrix [1 2; 3 4] built in PYDML, read back as (ID, c1, c2) rows.
+               Script pyScript = pydml("M = full('1 2 3 4', rows=2, cols=2)").out("M");
+               List<Row> rows = ml.execute(pyScript).getDataFrameDoubleWithIDColumn("M").collectAsList();
+
+               Row first = rows.get(0);
+               Assert.assertEquals(1.0, first.getDouble(0), 0.0);
+               Assert.assertEquals(1.0, first.getDouble(1), 0.0);
+               Assert.assertEquals(2.0, first.getDouble(2), 0.0);
+
+               Row second = rows.get(1);
+               Assert.assertEquals(2.0, second.getDouble(0), 0.0);
+               Assert.assertEquals(3.0, second.getDouble(1), 0.0);
+               Assert.assertEquals(4.0, second.getDouble(2), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameDMLDoublesNoIDColumn() {
+               System.out.println("MLContextTest - output DataFrame DML, doubles no ID column");
+
+               // 2x2 matrix [1 2; 3 4] read back as plain (c1, c2) double rows.
+               MLResults results = ml.execute(dml("M = matrix('1 2 3 4', rows=2, cols=2);").out("M"));
+               List<Row> rows = results.getDataFrameDoubleNoIDColumn("M").collectAsList();
+
+               Row first = rows.get(0);
+               Assert.assertEquals(1.0, first.getDouble(0), 0.0);
+               Assert.assertEquals(2.0, first.getDouble(1), 0.0);
+
+               Row second = rows.get(1);
+               Assert.assertEquals(3.0, second.getDouble(0), 0.0);
+               Assert.assertEquals(4.0, second.getDouble(1), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFramePYDMLDoublesNoIDColumn() {
+               System.out.println("MLContextTest - output DataFrame PYDML, doubles no ID column");
+
+               // 2x2 matrix [1 2; 3 4] built in PYDML, read back as plain (c1, c2) rows.
+               Script pyScript = pydml("M = full('1 2 3 4', rows=2, cols=2)").out("M");
+               List<Row> rows = ml.execute(pyScript).getDataFrameDoubleNoIDColumn("M").collectAsList();
+
+               Row first = rows.get(0);
+               Assert.assertEquals(1.0, first.getDouble(0), 0.0);
+               Assert.assertEquals(2.0, first.getDouble(1), 0.0);
+
+               Row second = rows.get(1);
+               Assert.assertEquals(3.0, second.getDouble(0), 0.0);
+               Assert.assertEquals(4.0, second.getDouble(1), 0.0);
+       }
+
+       @Test
        public void testTwoScriptsDML() {
                System.out.println("MLContextTest - two scripts with inputs and 
outputs DML");
 

Reply via email to