Repository: incubator-systemml Updated Branches: refs/heads/master 9f12b5c66 -> 97dee8fba
[SYSTEMML-834] Improve MLContext DataFrame support Closes #218. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/97dee8fb Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/97dee8fb Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/97dee8fb Branch: refs/heads/master Commit: 97dee8fba7252f9b868e23d69f23c36053f48445 Parents: 9f12b5c Author: Deron Eriksson <[email protected]> Authored: Thu Aug 25 14:33:00 2016 -0700 Committer: Deron Eriksson <[email protected]> Committed: Thu Aug 25 14:33:00 2016 -0700 ---------------------------------------------------------------------- .../sysml/api/mlcontext/BinaryBlockMatrix.java | 21 +- .../api/mlcontext/MLContextConversionUtil.java | 66 +++- .../apache/sysml/api/mlcontext/MLResults.java | 136 ++++++- .../org/apache/sysml/api/mlcontext/Matrix.java | 46 ++- .../sysml/api/mlcontext/MatrixFormat.java | 54 ++- .../integration/mlcontext/MLContextTest.java | 390 ++++++++++++++++++- 6 files changed, 689 insertions(+), 24 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java index ea6fcf0..b13669d 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java @@ -99,12 +99,21 @@ public class BinaryBlockMatrix { public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() { return binaryBlocks; } - - public MatrixBlock getMatrixBlock() throws DMLRuntimeException { - MatrixCharacteristics mc = getMatrixCharacteristics(); - MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(), - mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros()); - return mb; + + /** + * Obtain a SystemML binary-block matrix as a {@code MatrixBlock} + * + * @return the SystemML binary-block matrix as a {@code MatrixBlock} + */ + public MatrixBlock getMatrixBlock() { + try { + MatrixCharacteristics mc = getMatrixCharacteristics(); + MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(), + mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros()); + return mb; + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception while getting MatrixBlock from binary-block matrix", e); + } } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java index 33a5a3c..3a482ef 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -33,6 +33,7 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; @@ -311,7 +312,14 @@ public class MLContextConversionUtil { } else { matrixCharacteristics = new MatrixCharacteristics(); } - determineDataFrameDimensionsIfNeeded(dataFrame, matrixCharacteristics); + + if (isDataFrameWithIDColumn(matrixMetadata)) { + dataFrame = dataFrame.sort("ID").drop("ID"); + } + + boolean isVectorBasedDataFrame = isVectorBasedDataFrame(matrixMetadata); + + determineDataFrameDimensionsIfNeeded(dataFrame, matrixCharacteristics, isVectorBasedDataFrame); if (matrixMetadata != null) { // so external reference can be updated with the metadata matrixMetadata.setMatrixCharacteristics(matrixCharacteristics); @@ -320,12 +328,50 @@ public class MLContextConversionUtil { JavaRDD<Row> javaRDD = dataFrame.javaRDD(); JavaPairRDD<Row, Long> prepinput = javaRDD.zipWithIndex(); JavaPairRDD<MatrixIndexes, MatrixBlock> out = prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction( - matrixCharacteristics, false)); + matrixCharacteristics, isVectorBasedDataFrame)); out = RDDAggregateUtils.mergeByKey(out); return out; } /** + * Return whether or not the DataFrame has an ID column. + * + * @param matrixMetadata + * the matrix metadata + * @return {@code true} if the DataFrame has an ID column, {@code false} + * otherwise. + */ + public static boolean isDataFrameWithIDColumn(MatrixMetadata matrixMetadata) { + if (matrixMetadata == null) { + return false; + } + MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat(); + if (matrixFormat == null) { + return false; + } + return matrixFormat.hasIDColumn(); + } + + /** + * Return whether or not the DataFrame is vector-based. + * + * @param matrixMetadata + * the matrix metadata + * @return {@code true} if the DataFrame is vector-based, {@code false} + * otherwise. + */ + public static boolean isVectorBasedDataFrame(MatrixMetadata matrixMetadata) { + if (matrixMetadata == null) { + return false; + } + MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat(); + if (matrixFormat == null) { + return false; + } + return matrixFormat.isVectorBased(); + } + + /** * If the {@code DataFrame} dimensions aren't present in the * {@code MatrixCharacteristics} metadata, determine the dimensions and * place them in the {@code MatrixCharacteristics} metadata. @@ -334,20 +380,28 @@ public class MLContextConversionUtil { * the Spark {@code DataFrame} * @param matrixCharacteristics * the matrix metadata + * @param vectorBased + * is the DataFrame vector-based */ public static void determineDataFrameDimensionsIfNeeded(DataFrame dataFrame, - MatrixCharacteristics matrixCharacteristics) { + MatrixCharacteristics matrixCharacteristics, boolean vectorBased) { if (!matrixCharacteristics.dimsKnown(true)) { - // only available to the new MLContext API, not the old API MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); SparkContext sparkContext = activeMLContext.getSparkContext(); @SuppressWarnings("resource") JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext); Accumulator<Double> aNnz = javaSparkContext.accumulator(0L); - JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(aNnz, false)); + JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(aNnz, vectorBased)); long numRows = javaRDD.count(); - long numColumns = dataFrame.columns().length; + long numColumns; + if (vectorBased) { + Vector v = (Vector) javaRDD.first().get(0); + numColumns = v.size(); + } else { + numColumns = dataFrame.columns().length; + } + long numNonZeros = UtilFunctions.toLong(aNnz.value()); matrixCharacteristics.set(numRows, numColumns, matrixCharacteristics.getRowsPerBlock(), matrixCharacteristics.getColsPerBlock(), numNonZeros); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java index 31798e0..dbc8f5d 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java @@ -236,7 +236,7 @@ public class MLResults { } /** - * Obtain an output as a {@code DataFrame} of doubles. + * Obtain an output as a {@code DataFrame} of doubles with an ID column. * <p> * The following matrix in DML: * </p> @@ -245,13 +245,13 @@ public class MLResults { * <p> * is equivalent to the following {@code DataFrame} of doubles: * </p> - * <code>[0.0,1.0,2.0] - * <br>[1.0,3.0,4.0] + * <code>[1.0,1.0,2.0] + * <br>[2.0,3.0,4.0] * </code> * * @param outputName * the name of the output - * @return the output as a {@code DataFrame} of doubles + * @return the output as a {@code DataFrame} of doubles with an ID column */ public DataFrame getDataFrame(String outputName) { MatrixObject mo = getMatrixObject(outputName); @@ -259,6 +259,35 @@ public class MLResults { return df; } + /** + * Obtain an output as a {@code DataFrame} of doubles or vectors with an ID + * column. + * <p> + * The following matrix in DML: + * </p> + * <code>M = full('1 2 3 4', rows=2, cols=2); + * </code> + * <p> + * is equivalent to the following {@code DataFrame} of doubles: + * </p> + * <code>[1.0,1.0,2.0] + * <br>[2.0,3.0,4.0] + * </code> + * <p> + * or the following {@code DataFrame} of vectors: + * </p> + * <code>[1.0,[1.0,2.0]] + * <br>[2.0,[3.0,4.0]] + * </code> + * + * @param outputName + * the name of the output + * @param isVectorDF + * {@code true} for a vector {@code DataFrame}, {@code false} for + * a double {@code DataFrame} + * @return the output as a {@code DataFrame} of doubles or vectors with an + * ID column + */ public DataFrame getDataFrame(String outputName, boolean isVectorDF) { MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF); @@ -266,6 +295,104 @@ public class MLResults { } /** + * Obtain an output as a {@code DataFrame} of doubles with an ID column. + * <p> + * The following matrix in DML: + * </p> + * <code>M = full('1 2 3 4', rows=2, cols=2); + * </code> + * <p> + * is equivalent to the following {@code DataFrame} of doubles: + * </p> + * <code>[1.0,1.0,2.0] + * <br>[2.0,3.0,4.0] + * </code> + * + * @param outputName + * the name of the output + * @return the output as a {@code DataFrame} of doubles with an ID column + */ + public DataFrame getDataFrameDoubleWithIDColumn(String outputName) { + MatrixObject mo = getMatrixObject(outputName); + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); + return df; + } + + /** + * Obtain an output as a {@code DataFrame} of vectors with an ID column. + * <p> + * The following matrix in DML: + * </p> + * <code>M = full('1 2 3 4', rows=2, cols=2); + * </code> + * <p> + * is equivalent to the following {@code DataFrame} of vectors: + * </p> + * <code>[1.0,[1.0,2.0]] + * <br>[2.0,[3.0,4.0]] + * </code> + * + * @param outputName + * the name of the output + * @return the output as a {@code DataFrame} of vectors with an ID column + */ + public DataFrame getDataFrameVectorWithIDColumn(String outputName) { + MatrixObject mo = getMatrixObject(outputName); + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true); + return df; + } + + /** + * Obtain an output as a {@code DataFrame} of doubles with no ID column. + * <p> + * The following matrix in DML: + * </p> + * <code>M = full('1 2 3 4', rows=2, cols=2); + * </code> + * <p> + * is equivalent to the following {@code DataFrame} of doubles: + * </p> + * <code>[1.0,2.0] + * <br>[3.0,4.0] + * </code> + * + * @param outputName + * the name of the output + * @return the output as a {@code DataFrame} of doubles with no ID column + */ + public DataFrame getDataFrameDoubleNoIDColumn(String outputName) { + MatrixObject mo = getMatrixObject(outputName); + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); + df = df.sort("ID").drop("ID"); + return df; + } + + /** + * Obtain an output as a {@code DataFrame} of vectors with no ID column. + * <p> + * The following matrix in DML: + * </p> + * <code>M = full('1 2 3 4', rows=2, cols=2); + * </code> + * <p> + * is equivalent to the following {@code DataFrame} of vectors: + * </p> + * <code>[[1.0,2.0]] + * <br>[[3.0,4.0]] + * </code> + * + * @param outputName + * the name of the output + * @return the output as a {@code DataFrame} of vectors with no ID column + */ + public DataFrame getDataFrameVectorNoIDColumn(String outputName) { + MatrixObject mo = getMatrixObject(outputName); + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true); + df = df.sort("ID").drop("ID"); + return df; + } + + /** * Obtain an output as a {@code Matrix}. * * @param outputName @@ -278,7 +405,6 @@ public class MLResults { return matrix; } - /** * Obtain an output as a {@code BinaryBlockMatrix}. * http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java index 3ee41b7..abd785c 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java @@ -103,9 +103,9 @@ public class Matrix { } /** - * Obtain the matrix as a {@code DataFrame} + * Obtain the matrix as a {@code DataFrame} of doubles with an ID column * - * @return the matrix as a {@code DataFrame} + * @return the matrix as a {@code DataFrame} of doubles with an ID column */ public DataFrame asDataFrame() { DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false); @@ -113,6 +113,48 @@ public class Matrix { } /** + * Obtain the matrix as a {@code DataFrame} of doubles with an ID column + * + * @return the matrix as a {@code DataFrame} of doubles with an ID column + */ + public DataFrame asDataFrameDoubleWithIDColumn() { + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false); + return df; + } + + /** + * Obtain the matrix as a {@code DataFrame} of doubles with no ID column + * + * @return the matrix as a {@code DataFrame} of doubles with no ID column + */ + public DataFrame asDataFrameDoubleNoIDColumn() { + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false); + df = df.sort("ID").drop("ID"); + return df; + } + + /** + * Obtain the matrix as a {@code DataFrame} of vectors with an ID column + * + * @return the matrix as a {@code DataFrame} of vectors with an ID column + */ + public DataFrame asDataFrameVectorWithIDColumn() { + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, true); + return df; + } + + /** + * Obtain the matrix as a {@code DataFrame} of vectors with no ID column + * + * @return the matrix as a {@code DataFrame} of vectors with no ID column + */ + public DataFrame asDataFrameVectorNoIDColumn() { + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, true); + df = df.sort("ID").drop("ID"); + return df; + } + + /** * Obtain the matrix as a {@code BinaryBlockMatrix} * * @return the matrix as a {@code BinaryBlockMatrix} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java index 50ed634..a7ac395 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java @@ -34,6 +34,58 @@ public enum MatrixFormat { * (I J V) format (sparse). I and J represent matrix coordinates and V * represents the value. The I J and V values are space-separated. */ - IJV; + IJV, + + /** + * DataFrame of doubles with an ID column. + */ + DF_DOUBLES_WITH_ID_COLUMN, + + /** + * DataFrame of doubles with no ID column. + */ + DF_DOUBLES_WITH_NO_ID_COLUMN, + + /** + * Vector DataFrame with an ID column. + */ + DF_VECTOR_WITH_ID_COLUMN, + + /** + * Vector DataFrame with no ID column. + */ + DF_VECTOR_WITH_NO_ID_COLUMN; + + /** + * Is the matrix format vector-based? + * + * @return {@code true} if matrix is a vector-based DataFrame, {@code false} + * otherwise. + */ + public boolean isVectorBased() { + if (this == DF_VECTOR_WITH_ID_COLUMN) { + return true; + } else if (this == DF_VECTOR_WITH_NO_ID_COLUMN) { + return true; + } else { + return false; + } + } + + /** + * Does the DataFrame have an ID column? + * + * @return {@code true} if the DataFrame has an ID column, {@code false} + * otherwise. + */ + public boolean hasIDColumn() { + if (this == DF_DOUBLES_WITH_ID_COLUMN) { + return true; + } else if (this == DF_VECTOR_WITH_ID_COLUMN) { + return true; + } else { + return false; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index e6e1046..7be657b 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -46,6 +46,9 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; @@ -497,8 +500,8 @@ public class MLContextTest extends AutomatedTestBase { } @Test - public void testDataFrameSumDML() { - System.out.println("MLContextTest - DataFrame sum DML"); + public void testDataFrameSumDMLDoublesWithNoIDColumn() { + System.out.println("MLContextTest - DataFrame sum DML, doubles with no ID column"); List<String> list = new ArrayList<String>(); list.add("10,20,30"); @@ -521,8 +524,8 @@ public class MLContextTest extends AutomatedTestBase { } @Test - public void testDataFrameSumPYDML() { - System.out.println("MLContextTest - DataFrame sum PYDML"); + public void testDataFrameSumPYDMLDoublesWithNoIDColumn() { + System.out.println("MLContextTest - DataFrame sum PYDML, doubles with no ID column"); List<String> list = new ArrayList<String>(); list.add("10,20,30"); @@ -544,9 +547,236 @@ public class MLContextTest extends AutomatedTestBase { ml.execute(script); } + @Test + public void testDataFrameSumDMLDoublesWithIDColumn() { + System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column"); + + List<String> list = new ArrayList<String>(); + list.add("1,1,2,3"); + list.add("2,4,5,6"); + list.add("3,7,8,9"); + JavaRDD<String> javaRddString = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN); + + Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm); + setExpectedStdOut("sum: 45.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumPYDMLDoublesWithIDColumn() { + System.out.println("MLContextTest - DataFrame sum PYDML, doubles with ID column"); + + List<String> list = new ArrayList<String>(); + list.add("1,1,2,3"); + list.add("2,4,5,6"); + list.add("3,7,8,9"); + JavaRDD<String> javaRddString = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN); + + Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm); + setExpectedStdOut("sum: 45.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumDMLDoublesWithIDColumnSortCheck() { + System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column sort check"); + + List<String> list = new ArrayList<String>(); + list.add("3,7,8,9"); + list.add("1,1,2,3"); + list.add("2,4,5,6"); + JavaRDD<String> javaRddString = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN); + + Script script = dml("print('M[1,1]: ' + as.scalar(M[1,1]));").in("M", dataFrame, mm); + setExpectedStdOut("M[1,1]: 1.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumPYDMLDoublesWithIDColumnSortCheck() { + System.out.println("MLContextTest - DataFrame sum PYDML ID, doubles with ID column sort check"); + + List<String> list = new ArrayList<String>(); + list.add("3,7,8,9"); + list.add("1,1,2,3"); + list.add("2,4,5,6"); + JavaRDD<String> javaRddString = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN); + + Script script = pydml("print('M[0,0]: ' + scalar(M[0,0]))").in("M", dataFrame, mm); + setExpectedStdOut("M[0,0]: 1.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumDMLVectorWithIDColumn() { + System.out.println("MLContextTest - DataFrame sum DML, vector with ID column"); + + List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>(); + list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0))); + list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0))); + list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0))); + JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_ID_COLUMN); + + Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm); + setExpectedStdOut("sum: 45.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumPYDMLVectorWithIDColumn() { + System.out.println("MLContextTest - DataFrame sum PYDML, vector with ID column"); + + List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>(); + list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0))); + list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0))); + list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0))); + JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_ID_COLUMN); + + Script script = dml("print('sum: ' + sum(M))").in("M", dataFrame, mm); + setExpectedStdOut("sum: 45.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumDMLVectorWithNoIDColumn() { + System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column"); + + List<Vector> list = new ArrayList<Vector>(); + list.add(Vectors.dense(1.0, 2.0, 3.0)); + list.add(Vectors.dense(4.0, 5.0, 6.0)); + list.add(Vectors.dense(7.0, 8.0, 9.0)); + JavaRDD<Vector> javaRddVector = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN); + + Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm); + setExpectedStdOut("sum: 45.0"); + ml.execute(script); + } + + @Test + public void testDataFrameSumPYDMLVectorWithNoIDColumn() { + System.out.println("MLContextTest - DataFrame sum PYDML, vector with no ID column"); + + List<Vector> list = new ArrayList<Vector>(); + list.add(Vectors.dense(1.0, 2.0, 3.0)); + list.add(Vectors.dense(4.0, 5.0, 6.0)); + list.add(Vectors.dense(7.0, 8.0, 9.0)); + JavaRDD<Vector> javaRddVector = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow()); + SQLContext sqlContext = new SQLContext(sc); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN); + + Script script = dml("print('sum: ' + sum(M))").in("M", dataFrame, mm); + setExpectedStdOut("sum: 45.0"); + ml.execute(script); + } + + static class DoubleVectorRow implements Function<Tuple2<Double, Vector>, Row> { + private static final long serialVersionUID = 3605080559931384163L; + + @Override + public Row call(Tuple2<Double, Vector> tup) throws Exception { + Double doub = tup._1(); + Vector vect = tup._2(); + return RowFactory.create(doub, vect); + } + } + + static class VectorRow implements Function<Vector, Row> { + private static final long serialVersionUID = 7077761802433569068L; + + @Override + public Row call(Vector vect) throws Exception { + return RowFactory.create(vect); + } + } + static class CommaSeparatedValueStringToRow implements Function<String, Row> { private static final long serialVersionUID = -7871020122671747808L; + @Override public Row call(String str) throws Exception { String[] fields = str.split(","); return RowFactory.create((Object[]) fields); @@ -1032,6 +1262,158 @@ public class MLContextTest extends AutomatedTestBase { } @Test + public void testOutputDataFrameDMLVectorWithIDColumn() { + System.out.println("MLContextTest - output DataFrame DML, vector with ID column"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameVectorWithIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(1)).toArray(), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(1)).toArray(), 0.0); + } + + @Test + public void testOutputDataFramePYDMLVectorWithIDColumn() { + System.out.println("MLContextTest - output DataFrame PYDML, vector with ID column"); + + String s = "M = full('1 2 3 4', rows=2, cols=2)"; + Script script = pydml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameVectorWithIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(1)).toArray(), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(1)).toArray(), 0.0); + } + + @Test + public void testOutputDataFrameDMLVectorNoIDColumn() { + System.out.println("MLContextTest - output DataFrame DML, vector no ID column"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameVectorNoIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(0)).toArray(), 0.0); + + Row row2 = list.get(1); + Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(0)).toArray(), 0.0); + } + + @Test + public void testOutputDataFramePYDMLVectorNoIDColumn() { + System.out.println("MLContextTest - output DataFrame PYDML, vector no ID column"); + + String s = "M = full('1 2 3 4', rows=2, cols=2)"; + Script script = pydml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameVectorNoIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(0)).toArray(), 0.0); + + Row row2 = list.get(1); + Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(0)).toArray(), 0.0); + } + + @Test + public void testOutputDataFrameDMLDoublesWithIDColumn() { + System.out.println("MLContextTest - output DataFrame DML, doubles with ID column"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameDoubleWithIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertEquals(1.0, row1.getDouble(1), 0.0); + Assert.assertEquals(2.0, row1.getDouble(2), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Assert.assertEquals(3.0, row2.getDouble(1), 0.0); + Assert.assertEquals(4.0, row2.getDouble(2), 0.0); + } + + @Test + public void testOutputDataFramePYDMLDoublesWithIDColumn() { + System.out.println("MLContextTest - output DataFrame PYDML, doubles with ID column"); + + String s = "M = full('1 2 3 4', rows=2, cols=2)"; + Script script = pydml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameDoubleWithIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertEquals(1.0, row1.getDouble(1), 0.0); + Assert.assertEquals(2.0, row1.getDouble(2), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Assert.assertEquals(3.0, row2.getDouble(1), 0.0); + Assert.assertEquals(4.0, row2.getDouble(2), 0.0); + } + + @Test + public void testOutputDataFrameDMLDoublesNoIDColumn() { + System.out.println("MLContextTest - output DataFrame DML, doubles no ID column"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameDoubleNoIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertEquals(2.0, row1.getDouble(1), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(3.0, row2.getDouble(0), 0.0); + Assert.assertEquals(4.0, row2.getDouble(1), 0.0); + } + + @Test + public void testOutputDataFramePYDMLDoublesNoIDColumn() { + System.out.println("MLContextTest - output DataFrame PYDML, doubles no ID column"); + + String s = "M = full('1 2 3 4', rows=2, cols=2)"; + Script script = pydml(s).out("M"); + MLResults results = ml.execute(script); + DataFrame dataFrame = results.getDataFrameDoubleNoIDColumn("M"); + List<Row> list = dataFrame.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertEquals(2.0, row1.getDouble(1), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(3.0, row2.getDouble(0), 0.0); + Assert.assertEquals(4.0, row2.getDouble(1), 0.0); + } + + @Test public void testTwoScriptsDML() { System.out.println("MLContextTest - two scripts with inputs and outputs DML");
