Repository: incubator-systemml
Updated Branches:
  refs/heads/master 3877e3563 -> 01d643c67


[SYSTEMML-887] Automatic DataFrame type determination in MLContext

Closes #228.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/01d643c6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/01d643c6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/01d643c6

Branch: refs/heads/master
Commit: 01d643c677603c48d3d24cdba06dabe9fdb6d910
Parents: 3877e35
Author: Deron Eriksson <[email protected]>
Authored: Tue Aug 30 13:53:09 2016 -0700
Committer: Deron Eriksson <[email protected]>
Committed: Tue Aug 30 13:53:09 2016 -0700

----------------------------------------------------------------------
 .../api/mlcontext/MLContextConversionUtil.java  |  94 ++++++---
 .../org/apache/sysml/api/mlcontext/Script.java  |  35 +++-
 .../integration/mlcontext/MLContextTest.java    | 201 ++++++++++++++++++-
 3 files changed, 302 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/01d643c6/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 0c98dea..63de638 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.MLContextProxy;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -107,14 +108,14 @@ public class MLContextConversionUtil {
                        if (matrixMetadata != null) {
                                matrixCharacteristics = 
matrixMetadata.asMatrixCharacteristics();
                        } else {
-                               matrixCharacteristics = new 
MatrixCharacteristics(matrixBlock.getNumRows(),
-                                               matrixBlock.getNumColumns(), 
MLContextUtil.defaultBlockSize(), MLContextUtil.defaultBlockSize());
+                               matrixCharacteristics = new 
MatrixCharacteristics(matrixBlock.getNumRows(), matrixBlock.getNumColumns(),
+                                               
MLContextUtil.defaultBlockSize(), MLContextUtil.defaultBlockSize());
                        }
 
                        MatrixFormatMetaData meta = new 
MatrixFormatMetaData(matrixCharacteristics,
                                        OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo);
-                       MatrixObject matrixObject = new 
MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/"
-                                       + variableName, meta);
+                       MatrixObject matrixObject = new 
MatrixObject(ValueType.DOUBLE,
+                                       MLContextUtil.scratchSpace() + "/" + 
variableName, meta);
                        matrixObject.acquireModify(matrixBlock);
                        matrixObject.release();
                        return matrixObject;
@@ -176,10 +177,10 @@ public class MLContextConversionUtil {
                        } else {
                                matrixCharacteristics = new 
MatrixCharacteristics();
                        }
-                       MatrixFormatMetaData mtd = new 
MatrixFormatMetaData(matrixCharacteristics,
-                                       OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo);
-                       MatrixObject matrixObject = new 
MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/"
-                                       + variableName, mtd);
+                       MatrixFormatMetaData mtd = new 
MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo,
+                                       InputInfo.BinaryBlockInputInfo);
+                       MatrixObject matrixObject = new 
MatrixObject(ValueType.DOUBLE,
+                                       MLContextUtil.scratchSpace() + "/" + 
variableName, mtd);
                        matrixObject.acquireModify(matrixBlock);
                        matrixObject.release();
                        return matrixObject;
@@ -208,10 +209,9 @@ public class MLContextConversionUtil {
                        } else {
                                matrixCharacteristics = new 
MatrixCharacteristics();
                        }
-                       MatrixFormatMetaData mtd = new 
MatrixFormatMetaData(matrixCharacteristics,
-                                       OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo);
-                       FrameObject frameObject = new 
FrameObject(MLContextUtil.scratchSpace() + "/"
-                                       + variableName, mtd);
+                       MatrixFormatMetaData mtd = new 
MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo,
+                                       InputInfo.BinaryBlockInputInfo);
+                       FrameObject frameObject = new 
FrameObject(MLContextUtil.scratchSpace() + "/" + variableName, mtd);
                        frameObject.acquireModify(frameBlock);
                        frameObject.release();
                        return frameObject;
@@ -263,9 +263,9 @@ public class MLContextConversionUtil {
 
                JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRdd = 
binaryBlocks.mapToPair(new CopyBlockPairFunction());
 
-               MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, 
MLContextUtil.scratchSpace() + "/" + "temp_"
-                               + System.nanoTime(), new 
MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo,
-                               InputInfo.BinaryBlockInputInfo));
+               MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE,
+                               MLContextUtil.scratchSpace() + "/" + "temp_" + 
System.nanoTime(), new MatrixFormatMetaData(
+                                               matrixCharacteristics, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
                matrixObject.setRDDHandle(new RDDObject(javaPairRdd, 
variableName));
                return matrixObject;
        }
@@ -301,8 +301,8 @@ public class MLContextConversionUtil {
                if (matrixMetadata == null) {
                        matrixMetadata = new MatrixMetadata();
                }
-               JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = 
MLContextConversionUtil.dataFrameToBinaryBlocks(
-                               dataFrame, matrixMetadata);
+               JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = 
MLContextConversionUtil.dataFrameToBinaryBlocks(dataFrame,
+                               matrixMetadata);
                MatrixObject matrixObject = 
MLContextConversionUtil.binaryBlocksToMatrixObject(variableName, binaryBlock,
                                matrixMetadata);
                return matrixObject;
@@ -337,6 +337,8 @@ public class MLContextConversionUtil {
        public static JavaPairRDD<MatrixIndexes, MatrixBlock> 
dataFrameToBinaryBlocks(DataFrame dataFrame,
                        MatrixMetadata matrixMetadata) {
 
+               determineMatrixFormatIfNeeded(dataFrame, matrixMetadata);
+
                MatrixCharacteristics matrixCharacteristics;
                if (matrixMetadata != null) {
                        matrixCharacteristics = 
matrixMetadata.asMatrixCharacteristics();
@@ -361,13 +363,57 @@ public class MLContextConversionUtil {
 
                JavaRDD<Row> javaRDD = dataFrame.javaRDD();
                JavaPairRDD<Row, Long> prepinput = javaRDD.zipWithIndex();
-               JavaPairRDD<MatrixIndexes, MatrixBlock> out = 
prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction(
-                               matrixCharacteristics, isVectorBasedDataFrame));
+               JavaPairRDD<MatrixIndexes, MatrixBlock> out = prepinput
+                               .mapPartitionsToPair(new 
DataFrameToBinaryBlockFunction(matrixCharacteristics, isVectorBasedDataFrame));
                out = RDDAggregateUtils.mergeByKey(out);
                return out;
        }
 
        /**
+        * If the MatrixFormat of the DataFrame has not been explicitly specified,
+        * attempt to determine the proper MatrixFormat and record it in the
+        * supplied matrix metadata.
+        * <p>
+        * The format is inferred from two observations: whether the schema has a
+        * column named {@code ID}, and whether the value at column 1 (ID present)
+        * or column 0 (no ID) of the first row is a {@code Vector}. Reading the
+        * first row requires evaluating the DataFrame.
+        * 
+        * @param dataFrame
+        *            the Spark {@code DataFrame}
+        * @param matrixMetadata
+        *            the matrix metadata, if available; when {@code null} there is
+        *            nothing to update and the method returns immediately
+        */
+       public static void determineMatrixFormatIfNeeded(DataFrame dataFrame, MatrixMetadata matrixMetadata) {
+               // dataFrameToBinaryBlocks calls this before its own null check on the
+               // metadata, so a null argument must be tolerated here.
+               if (matrixMetadata == null || matrixMetadata.getMatrixFormat() != null) {
+                       return;
+               }
+
+               StructType schema = dataFrame.schema();
+               boolean hasID;
+               try {
+                       schema.fieldIndex("ID");
+                       hasID = true;
+               } catch (IllegalArgumentException ignored) {
+                       // fieldIndex reports a missing column by throwing; that simply
+                       // means the DataFrame has no ID column.
+                       hasID = false;
+               }
+
+               Row firstRow = dataFrame.first();
+               boolean isVector = firstRow.get(hasID ? 1 : 0) instanceof Vector;
+
+               MatrixFormat matrixFormat;
+               if (hasID) {
+                       matrixFormat = isVector ? MatrixFormat.DF_VECTOR_WITH_ID_COLUMN
+                                       : MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN;
+               } else {
+                       matrixFormat = isVector ? MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN
+                                       : MatrixFormat.DF_DOUBLES_WITH_NO_ID_COLUMN;
+               }
+               matrixMetadata.setMatrixFormat(matrixFormat);
+       }
+
+       /**
         * Return whether or not the DataFrame has an ID column.
         * 
         * @param matrixMetadata
@@ -475,8 +521,8 @@ public class MLContextConversionUtil {
                } else {
                        matrixCharacteristics = new MatrixCharacteristics();
                }
-               MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, 
null, new MatrixFormatMetaData(
-                               matrixCharacteristics, 
OutputInfo.CSVOutputInfo, InputInfo.CSVInputInfo));
+               MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, 
null,
+                               new MatrixFormatMetaData(matrixCharacteristics, 
OutputInfo.CSVOutputInfo, InputInfo.CSVInputInfo));
                JavaPairRDD<LongWritable, Text> javaPairRDD2 = 
javaPairRDD.mapToPair(new CopyTextInputFunction());
                matrixObject.setRDDHandle(new RDDObject(javaPairRDD2, 
variableName));
                return matrixObject;
@@ -747,7 +793,8 @@ public class MLContextConversionUtil {
                        matrixObject.release();
                        return list;
                } catch (CacheException e) {
-                       throw new MLContextException("Cache exception while 
converting matrix object to List<String> CSV format", e);
+                       throw new MLContextException("Cache exception while 
converting matrix object to List<String> CSV format",
+                                       e);
                }
        }
 
@@ -800,7 +847,8 @@ public class MLContextConversionUtil {
                        matrixObject.release();
                        return list;
                } catch (CacheException e) {
-                       throw new MLContextException("Cache exception while 
converting matrix object to List<String> IJV format", e);
+                       throw new MLContextException("Cache exception while 
converting matrix object to List<String> IJV format",
+                                       e);
                }
        }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/01d643c6/src/main/java/org/apache/sysml/api/mlcontext/Script.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java 
b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
index 28667cf..dcd8fc3 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/Script.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
@@ -305,7 +305,7 @@ public class Script {
         * @return {@code this} Script object to allow chaining of methods
         */
        public Script in(String name, Object value) {
-               return in(name, value, null);
+               return in(name, value, (MatrixMetadata) null);
        }
 
        public Script input(String name, Object value) {
@@ -320,6 +320,39 @@ public class Script {
         *            name of the input
         * @param value
         *            value of the input
+        * @param matrixFormat
+        *            optional matrix format
+        * @return {@code this} Script object to allow chaining of methods
+        */
+       public Script in(String name, Object value, MatrixFormat matrixFormat) {
+               // Wrap the bare format in metadata and hand off to the
+               // metadata-based overload.
+               return in(name, value, new MatrixMetadata(matrixFormat));
+       }
+
+       /**
+        * Register an input (parameter ($) or variable) with an optional matrix
+        * format. Delegates to {@code in(name, value, matrixFormat)}.
+        *
+        * @param name
+        *            name of the input
+        * @param value
+        *            value of the input
+        * @param matrixFormat
+        *            optional matrix format
+        * @return {@code this} Script object to allow chaining of methods
+        */
+       public Script input(String name, Object value, MatrixFormat 
matrixFormat) {
+               return in(name, value, matrixFormat);
+       }
+
+       /**
+        * Register an input (parameter ($) or variable) with optional matrix
+        * metadata.
+        *
+        * @param name
+        *            name of the input
+        * @param value
+        *            value of the input
         * @param matrixMetadata
         *            optional matrix metadata
         * @return {@code this} Script object to allow chaining of methods

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/01d643c6/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index 7be657b..e65dfe7 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -518,7 +518,9 @@ public class MLContextTest extends AutomatedTestBase {
                StructType schema = DataTypes.createStructType(fields);
                DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
 
-               Script script = dml("print('sum: ' + sum(M));").in("M", 
dataFrame);
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_NO_ID_COLUMN);
+
+               Script script = dml("print('sum: ' + sum(M));").in("M", 
dataFrame, mm);
                setExpectedStdOut("sum: 450.0");
                ml.execute(script);
        }
@@ -542,7 +544,9 @@ public class MLContextTest extends AutomatedTestBase {
                StructType schema = DataTypes.createStructType(fields);
                DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, 
schema);
 
-               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame);
+               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_NO_ID_COLUMN);
+
+               Script script = pydml("print('sum: ' + sum(M))").in("M", 
dataFrame, mm);
                setExpectedStdOut("sum: 450.0");
                ml.execute(script);
        }
@@ -2065,12 +2069,200 @@ public class MLContextTest extends AutomatedTestBase {
                ml.execute(script);
        }
 
+       /**
+        * DML sum over a DataFrame with three string-typed columns (C1-C3) and no
+        * ID column. Neither MatrixMetadata nor a MatrixFormat is passed to the
+        * script, so the format has to be determined automatically.
+        */
+       @Test
+       public void testDataFrameSumDMLDoublesWithNoIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum DML, doubles with no ID column, no format specified");
+
+               List<String> csvLines = new ArrayList<String>();
+               csvLines.add("2,2,2");
+               csvLines.add("3,3,3");
+               csvLines.add("4,4,4");
+               JavaRDD<Row> rowRdd = sc.parallelize(csvLines).map(new CommaSeparatedValueStringToRow());
+
+               List<StructField> columns = new ArrayList<StructField>();
+               for (String columnName : new String[] { "C1", "C2", "C3" }) {
+                       columns.add(DataTypes.createStructField(columnName, DataTypes.StringType, true));
+               }
+               StructType schema = DataTypes.createStructType(columns);
+               DataFrame df = new SQLContext(sc).createDataFrame(rowRdd, schema);
+
+               Script sumScript = dml("print('sum: ' + sum(M));").in("M", df);
+               setExpectedStdOut("sum: 27.0");
+               ml.execute(sumScript);
+       }
+
+       /**
+        * PYDML sum over a DataFrame with three string-typed columns (C1-C3) and
+        * no ID column. Neither MatrixMetadata nor a MatrixFormat is passed to the
+        * script, so the format has to be determined automatically.
+        */
+       @Test
+       public void testDataFrameSumPYDMLDoublesWithNoIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, doubles with no ID column, no format specified");
+
+               List<String> csvLines = new ArrayList<String>();
+               csvLines.add("2,2,2");
+               csvLines.add("3,3,3");
+               csvLines.add("4,4,4");
+               JavaRDD<Row> rowRdd = sc.parallelize(csvLines).map(new CommaSeparatedValueStringToRow());
+
+               List<StructField> columns = new ArrayList<StructField>();
+               for (String columnName : new String[] { "C1", "C2", "C3" }) {
+                       columns.add(DataTypes.createStructField(columnName, DataTypes.StringType, true));
+               }
+               StructType schema = DataTypes.createStructType(columns);
+               DataFrame df = new SQLContext(sc).createDataFrame(rowRdd, schema);
+
+               Script sumScript = pydml("print('sum: ' + sum(M))").in("M", df);
+               setExpectedStdOut("sum: 27.0");
+               ml.execute(sumScript);
+       }
+
+       /**
+        * DML sum over a DataFrame whose first column is named "ID", followed by
+        * three string-typed value columns (C1-C3). No MatrixMetadata or
+        * MatrixFormat is passed to the script. The expected sum (27.0) covers
+        * only the value columns, not the ID values.
+        */
+       @Test
+       public void testDataFrameSumDMLDoublesWithIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column, no format specified");
+
+               List<String> csvLines = new ArrayList<String>();
+               csvLines.add("1,2,2,2");
+               csvLines.add("2,3,3,3");
+               csvLines.add("3,4,4,4");
+               JavaRDD<Row> rowRdd = sc.parallelize(csvLines).map(new CommaSeparatedValueStringToRow());
+
+               List<StructField> columns = new ArrayList<StructField>();
+               for (String columnName : new String[] { "ID", "C1", "C2", "C3" }) {
+                       columns.add(DataTypes.createStructField(columnName, DataTypes.StringType, true));
+               }
+               StructType schema = DataTypes.createStructType(columns);
+               DataFrame df = new SQLContext(sc).createDataFrame(rowRdd, schema);
+
+               Script sumScript = dml("print('sum: ' + sum(M));").in("M", df);
+               setExpectedStdOut("sum: 27.0");
+               ml.execute(sumScript);
+       }
+
+       /**
+        * PYDML sum over a DataFrame whose first column is named "ID", followed
+        * by three string-typed value columns (C1-C3). No MatrixMetadata or
+        * MatrixFormat is passed to the script. The expected sum (27.0) covers
+        * only the value columns, not the ID values.
+        */
+       @Test
+       public void testDataFrameSumPYDMLDoublesWithIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, doubles with ID column, no format specified");
+
+               List<String> csvLines = new ArrayList<String>();
+               csvLines.add("1,2,2,2");
+               csvLines.add("2,3,3,3");
+               csvLines.add("3,4,4,4");
+               JavaRDD<Row> rowRdd = sc.parallelize(csvLines).map(new CommaSeparatedValueStringToRow());
+
+               List<StructField> columns = new ArrayList<StructField>();
+               for (String columnName : new String[] { "ID", "C1", "C2", "C3" }) {
+                       columns.add(DataTypes.createStructField(columnName, DataTypes.StringType, true));
+               }
+               StructType schema = DataTypes.createStructType(columns);
+               DataFrame df = new SQLContext(sc).createDataFrame(rowRdd, schema);
+
+               Script sumScript = pydml("print('sum: ' + sum(M))").in("M", df);
+               setExpectedStdOut("sum: 27.0");
+               ml.execute(sumScript);
+       }
+
+       /**
+        * DML sum over a DataFrame with an "ID" column followed by a single
+        * vector column (C1). No MatrixMetadata or MatrixFormat is passed to the
+        * script. The expected sum (45.0) is the total of the vector entries.
+        */
+       @Test
+       public void testDataFrameSumDMLVectorWithIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum DML, vector with ID column, no format specified");
+
+               List<Tuple2<Double, Vector>> idVectorPairs = new ArrayList<Tuple2<Double, Vector>>();
+               idVectorPairs.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
+               idVectorPairs.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
+               idVectorPairs.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
+               JavaRDD<Row> rowRdd = sc.parallelize(idVectorPairs).map(new DoubleVectorRow());
+
+               List<StructField> columns = new ArrayList<StructField>();
+               columns.add(DataTypes.createStructField("ID", DataTypes.StringType, true));
+               columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(columns);
+               DataFrame df = new SQLContext(sc).createDataFrame(rowRdd, schema);
+
+               Script sumScript = dml("print('sum: ' + sum(M));").in("M", df);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(sumScript);
+       }
+
+       /**
+        * PYDML sum over a DataFrame with an "ID" column followed by a single
+        * vector column (C1). No MatrixMetadata or MatrixFormat is passed to the
+        * script. The expected sum (45.0) is the total of the vector entries.
+        */
+       @Test
+       public void testDataFrameSumPYDMLVectorWithIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, vector with ID column, no format specified");
+
+               List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>();
+               list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
+               list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
+               list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
+               JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true));
+               fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+               // This is the PYDML variant, so the script must be built with
+               // pydml(...); dml(...) would only re-run the DML test.
+               Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
+
+       /**
+        * DML sum over a DataFrame consisting of a single vector column (C1) and
+        * no ID column. No MatrixMetadata or MatrixFormat is passed to the script.
+        * The expected sum (45.0) is the total of the vector entries.
+        */
+       @Test
+       public void testDataFrameSumDMLVectorWithNoIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column, no format specified");
+
+               List<Vector> vectors = new ArrayList<Vector>();
+               vectors.add(Vectors.dense(1.0, 2.0, 3.0));
+               vectors.add(Vectors.dense(4.0, 5.0, 6.0));
+               vectors.add(Vectors.dense(7.0, 8.0, 9.0));
+               JavaRDD<Row> rowRdd = sc.parallelize(vectors).map(new VectorRow());
+
+               List<StructField> columns = new ArrayList<StructField>();
+               columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(columns);
+               DataFrame df = new SQLContext(sc).createDataFrame(rowRdd, schema);
+
+               Script sumScript = dml("print('sum: ' + sum(M));").in("M", df);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(sumScript);
+       }
+
+       /**
+        * PYDML sum over a DataFrame consisting of a single vector column (C1)
+        * and no ID column. No MatrixMetadata or MatrixFormat is passed to the
+        * script. The expected sum (45.0) is the total of the vector entries.
+        */
+       @Test
+       public void testDataFrameSumPYDMLVectorWithNoIDColumnNoFormatSpecified() {
+               System.out.println("MLContextTest - DataFrame sum PYDML, vector with no ID column, no format specified");
+
+               List<Vector> list = new ArrayList<Vector>();
+               list.add(Vectors.dense(1.0, 2.0, 3.0));
+               list.add(Vectors.dense(4.0, 5.0, 6.0));
+               list.add(Vectors.dense(7.0, 8.0, 9.0));
+               JavaRDD<Vector> javaRddVector = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow());
+               SQLContext sqlContext = new SQLContext(sc);
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+               StructType schema = DataTypes.createStructType(fields);
+               DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+               // This is the PYDML variant, so the script must be built with
+               // pydml(...); dml(...) would only re-run the DML test.
+               Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame);
+               setExpectedStdOut("sum: 45.0");
+               ml.execute(script);
+       }
        // NOTE: Uncomment these tests once they work
 
        // @SuppressWarnings({ "rawtypes", "unchecked" })
        // @Test
        // public void testInputTupleSeqWithAndWithoutMetadataDML() {
-       // System.out.println("MLContextTest - Tuple sequence with and without 
metadata DML");
+       // System.out.println("MLContextTest - Tuple sequence with and without
+       // metadata DML");
        //
        // List<String> list1 = new ArrayList<String>();
        // list1.add("1,2");
@@ -2102,7 +2294,8 @@ public class MLContextTest extends AutomatedTestBase {
        // @SuppressWarnings({ "rawtypes", "unchecked" })
        // @Test
        // public void testInputTupleSeqWithAndWithoutMetadataPYDML() {
-       // System.out.println("MLContextTest - Tuple sequence with and without 
metadata PYDML");
+       // System.out.println("MLContextTest - Tuple sequence with and without
+       // metadata PYDML");
        //
        // List<String> list1 = new ArrayList<String>();
        // list1.add("1,2");

Reply via email to