Repository: incubator-systemml
Updated Branches:
  refs/heads/master 4bc6601d6 -> 0b472b09e


[SYSTEMML-568] Frame Schema support through MLContext

Closes #250


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/3957c0fa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/3957c0fa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/3957c0fa

Branch: refs/heads/master
Commit: 3957c0fa6aadb2213e96b6fa0d4f26271c711868
Parents: 61a6dcb
Author: Arvind Surve <ac...@yahoo.com>
Authored: Wed Sep 21 22:18:49 2016 -0700
Committer: Arvind Surve <ac...@yahoo.com>
Committed: Wed Sep 21 22:18:49 2016 -0700

----------------------------------------------------------------------
 .../api/mlcontext/MLContextConversionUtil.java  |  14 ++-
 .../controlprogram/caching/FrameObject.java     |  14 +++
 .../spark/utils/FrameRDDConverterUtils.java     |  11 +-
 .../mlcontext/MLContextFrameTest.java           | 116 +++++++++++++++----
 4 files changed, 124 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3957c0fa/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 1adc089..aa0366d 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -193,7 +193,7 @@ public class MLContextConversionUtil {
                                        frameMetadata.asMatrixCharacteristics() 
: new MatrixCharacteristics();
                        MatrixFormatMetaData mtd = new MatrixFormatMetaData(mc, 
                                        OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo);
-                       FrameObject frameObject = new 
FrameObject(OptimizerUtils.getUniqueTempFileName(), mtd);
+                       FrameObject frameObject = new 
FrameObject(OptimizerUtils.getUniqueTempFileName(), mtd, 
frameMetadata.getFrameSchema().getSchema());
                        frameObject.acquireModify(frameBlock);
                        frameObject.release();
                        return frameObject;
@@ -282,7 +282,7 @@ public class MLContextConversionUtil {
                                frameMetadata.asMatrixCharacteristics() : new 
MatrixCharacteristics();
 
                FrameObject frameObject = new 
FrameObject(OptimizerUtils.getUniqueTempFileName(), 
-                               new MatrixFormatMetaData(mc, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+                               new MatrixFormatMetaData(mc, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), 
frameMetadata.getFrameSchema().getSchema());
                frameObject.setRDDHandle(new RDDObject(binaryBlocks, 
variableName));
                return frameObject;
        }
@@ -365,6 +365,12 @@ public class MLContextConversionUtil {
                                matrixCharacteristics.setDimension(rows, cols);
                                
frameMetadata.setMatrixCharacteristics(matrixCharacteristics);
                        }
+                       
+                       List<String> colnames = new ArrayList<String>();
+                       List<ValueType> fschema = new ArrayList<ValueType>();
+                       
FrameRDDConverterUtils.convertDFSchemaToFrameSchema(dataFrame.schema(), 
colnames, fschema, containsID); 
+                       frameMetadata.setFrameSchema(new FrameSchema(fschema));
+
                        JavaPairRDD<Long, FrameBlock> binaryBlock = 
FrameRDDConverterUtils.dataFrameToBinaryBlock(javaSparkContext,
                                        dataFrame, matrixCharacteristics, 
containsID);
 
@@ -598,7 +604,7 @@ public class MLContextConversionUtil {
                JavaSparkContext jsc = 
MLContextUtil.getJavaSparkContext((MLContext) 
MLContextProxy.getActiveMLContextForAPI());
 
                FrameObject frameObject = new 
FrameObject(OptimizerUtils.getUniqueTempFileName(), 
-                               new MatrixFormatMetaData(mc, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+                               new MatrixFormatMetaData(mc, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), 
frameMetadata.getFrameSchema().getSchema());
                JavaPairRDD<Long, FrameBlock> rdd;
                try {
                        rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc, 
javaPairRDDText, mc, 
@@ -659,7 +665,7 @@ public class MLContextConversionUtil {
                JavaSparkContext jsc = 
MLContextUtil.getJavaSparkContext((MLContext) 
MLContextProxy.getActiveMLContextForAPI());
 
                FrameObject frameObject = new 
FrameObject(OptimizerUtils.getUniqueTempFileName(), 
-                               new MatrixFormatMetaData(mc, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+                               new MatrixFormatMetaData(mc, 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), 
frameMetadata.getFrameSchema().getSchema());
                JavaPairRDD<Long, FrameBlock> rdd;
                try {
                        List<ValueType> lschema = null;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3957c0fa/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
index 1209064..8c5fe8b 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
@@ -79,6 +79,20 @@ public class FrameObject extends CacheableData<FrameBlock>
        }
        
        /**
+        * 
+        * @param fname
+        * @param meta
+        * @param schema
+        * 
+        */
+       public FrameObject(String fname, MetaData meta, List<ValueType> schema) 
{
+               this();
+               setFileName(fname);
+               setMetaData(meta);
+               setSchema(schema);
+       }
+       
+       /**
         * Copy constructor that copies meta data but NO data.
         * 
         * @param fo

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3957c0fa/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
index 351d559..ede0211 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
@@ -326,12 +326,9 @@ public class FrameRDDConverterUtils
        /**
         * 
         * @param sc
-        * @param input
-        * @param mcOut
-        * @param hasHeader
-        * @param delim
-        * @param fill
-        * @param missingValue
+        * @param df
+        * @param mc
+        * @param containsID
         * @return
         * @throws DMLRuntimeException
         */
@@ -889,7 +886,7 @@ public class FrameRDDConverterUtils
                        int cols = blk.getNumColumns();
                        for( int i=0; i<rows; i++ ) {
                                Object[] row = new Object[cols+1];
-                               row[0] = rowIndex++;
+                               row[0] = (double)rowIndex++;
                                for( int j=0; j<cols; j++ )
                                        row[j+1] = blk.get(i, j);
                                ret.add(RowFactory.create(row));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3957c0fa/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
index deac382..972e6ea 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
@@ -33,12 +33,14 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
+import org.apache.sysml.api.mlcontext.FrameSchema;
 import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -181,6 +183,10 @@ public class MLContextFrameTest extends AutomatedTestBase {
                List<String> listB = new ArrayList<String>();
                FrameMetadata fmA = null, fmB = null;
                Script script = null;
+               List<ValueType> lschemaA = Arrays.asList(ValueType.INT, 
ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN);
+               FrameSchema fschemaA = new FrameSchema(lschemaA);
+               List<ValueType> lschemaB = Arrays.asList(ValueType.STRING, 
ValueType.DOUBLE, ValueType.BOOLEAN);
+               FrameSchema fschemaB = new FrameSchema(lschemaB);
 
                if (inputType != IO_TYPE.FILE) {
                        if (format == FrameFormat.CSV) {
@@ -191,8 +197,8 @@ public class MLContextFrameTest extends AutomatedTestBase {
                                listB.add("Str12,13.0,true");
                                listB.add("Str25,26.0,false");
 
-                               fmA = new FrameMetadata(FrameFormat.CSV, 3, 4);
-                               fmB = new FrameMetadata(FrameFormat.CSV, 2, 3);
+                               fmA = new FrameMetadata(FrameFormat.CSV, 
fschemaA, 3, 4);
+                               fmB = new FrameMetadata(FrameFormat.CSV, 
fschemaB, 2, 3);
                        } else if (format == FrameFormat.IJV) {
                                listA.add("1 1 1");
                                listA.add("1 2 Str2");
@@ -214,8 +220,8 @@ public class MLContextFrameTest extends AutomatedTestBase {
                                listB.add("2 2 26.0");
                                listB.add("2 3 false");
 
-                               fmA = new FrameMetadata(FrameFormat.IJV, 3, 4);
-                               fmB = new FrameMetadata(FrameFormat.IJV, 2, 3);
+                               fmA = new FrameMetadata(FrameFormat.IJV, 
fschemaA, 3, 4);
+                               fmB = new FrameMetadata(FrameFormat.IJV, 
fschemaB, 2, 3);
                        }
                        JavaRDD<String> javaRDDA = sc.parallelize(listA);
                        JavaRDD<String> javaRDDB = sc.parallelize(listB);
@@ -224,11 +230,6 @@ public class MLContextFrameTest extends AutomatedTestBase {
                                JavaRDD<Row> javaRddRowA = javaRDDA.map(new 
MLContextTest.CommaSeparatedValueStringToRow());
                                JavaRDD<Row> javaRddRowB = javaRDDB.map(new 
MLContextTest.CommaSeparatedValueStringToRow());
 
-                               ValueType[] schemaA = { ValueType.INT, 
ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
-                               List<ValueType> lschemaA = 
Arrays.asList(schemaA);
-                               ValueType[] schemaB = { ValueType.STRING, 
ValueType.DOUBLE, ValueType.BOOLEAN };
-                               List<ValueType> lschemaB = 
Arrays.asList(schemaB);
-
                                // Create DataFrame
                                SQLContext sqlContext = new SQLContext(sc);
                                StructType dfSchemaA = 
FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaA, false);
@@ -302,6 +303,24 @@ public class MLContextFrameTest extends AutomatedTestBase {
                }
 
                MLResults mlResults = ml.execute(script);
+               
+               //Validate output schema
+               List<ValueType> lschemaOutA = 
mlResults.getFrameObject("A").getSchema();
+               List<ValueType> lschemaOutC = 
mlResults.getFrameObject("C").getSchema();
+               if(inputType != IO_TYPE.FILE) {
+                       Assert.assertEquals(ValueType.INT, lschemaOutA.get(0));
+                       Assert.assertEquals(ValueType.STRING, 
lschemaOutA.get(1));
+                       Assert.assertEquals(ValueType.DOUBLE, 
lschemaOutA.get(2));
+                       Assert.assertEquals(ValueType.BOOLEAN, 
lschemaOutA.get(3));
+                       
+                       Assert.assertEquals(ValueType.STRING, 
lschemaOutC.get(0));
+                       Assert.assertEquals(ValueType.DOUBLE, 
lschemaOutC.get(1));
+               } else {
+                       for (int i=0; i < lschemaOutA.size(); i++)
+                               Assert.assertEquals(ValueType.STRING, 
lschemaOutA.get(i));
+                       for (int i=0; i < lschemaOutC.size(); i++)
+                               Assert.assertEquals(ValueType.STRING, 
lschemaOutC.get(i));
+               }
 
                if (outputType == IO_TYPE.JAVA_RDD_STR_CSV) {
 
@@ -370,30 +389,46 @@ public class MLContextFrameTest extends AutomatedTestBase {
                } else if (outputType == IO_TYPE.DATAFRAME) {
 
                        DataFrame dataFrameA = 
mlResults.getDataFrame("A").drop(RDDConverterUtils.DF_ID_COLUMN);
+                       StructType dfschemaA = dataFrameA.schema(); 
+                       StructField structTypeA = dfschemaA.apply(0);
+                       Assert.assertEquals(DataTypes.LongType, 
structTypeA.dataType());
+                       structTypeA = dfschemaA.apply(1);
+                       Assert.assertEquals(DataTypes.StringType, 
structTypeA.dataType());
+                       structTypeA = dfschemaA.apply(2);
+                       Assert.assertEquals(DataTypes.DoubleType, 
structTypeA.dataType());
+                       structTypeA = dfschemaA.apply(3);
+                       Assert.assertEquals(DataTypes.BooleanType, 
structTypeA.dataType());
+
                        List<Row> listAOut = dataFrameA.collectAsList();
 
                        Row row1 = listAOut.get(0);
-                       Assert.assertEquals("Mistmatch with expected value", 
"1", row1.get(0).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"Str2", row1.get(1).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"3.0", row1.get(2).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"true", row1.get(3).toString());
+                       Assert.assertEquals("Mistmatch with expected value", 
Long.valueOf(1), row1.get(0));
+                       Assert.assertEquals("Mistmatch with expected value", 
"Str2", row1.get(1));
+                       Assert.assertEquals("Mistmatch with expected value", 
3.0, row1.get(2));
+                       Assert.assertEquals("Mistmatch with expected value", 
true, row1.get(3));
                        
                        Row row2 = listAOut.get(1);
-                       Assert.assertEquals("Mistmatch with expected value", 
"4", row2.get(0).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"Str12", row2.get(1).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"13.0", row2.get(2).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"true", row2.get(3).toString());
+                       Assert.assertEquals("Mistmatch with expected value", 
Long.valueOf(4), row2.get(0));
+                       Assert.assertEquals("Mistmatch with expected value", 
"Str12", row2.get(1));
+                       Assert.assertEquals("Mistmatch with expected value", 
13.0, row2.get(2));
+                       Assert.assertEquals("Mistmatch with expected value", 
true, row2.get(3));
 
                        DataFrame dataFrameC = 
mlResults.getDataFrame("C").drop(RDDConverterUtils.DF_ID_COLUMN);
+                       StructType dfschemaC = dataFrameC.schema(); 
+                       StructField structTypeC = dfschemaC.apply(0);
+                       Assert.assertEquals(DataTypes.StringType, 
structTypeC.dataType());
+                       structTypeC = dfschemaC.apply(1);
+                       Assert.assertEquals(DataTypes.DoubleType, 
structTypeC.dataType());
+                       
                        List<Row> listCOut = dataFrameC.collectAsList();
 
                        Row row3 = listCOut.get(0);
-                       Assert.assertEquals("Mistmatch with expected value", 
"Str12", row3.get(0).toString());
-                       Assert.assertEquals("Mistmatch with expected value", 
"13.0", row3.get(1).toString());
+                       Assert.assertEquals("Mistmatch with expected value", 
"Str12", row3.get(0));
+                       Assert.assertEquals("Mistmatch with expected value", 
13.0, row3.get(1));
 
                        Row row4 = listCOut.get(1);
                        Assert.assertEquals("Mistmatch with expected value", 
"Str25", row4.get(0));
-                       Assert.assertEquals("Mistmatch with expected value", 
"26.0", row4.get(1));
+                       Assert.assertEquals("Mistmatch with expected value", 
26.0, row4.get(1));
                } else {
                        String[][] frameA = 
mlResults.getFrameAs2DStringArray("A");
                        Assert.assertEquals("Str2", frameA[0][1]);
@@ -485,6 +520,47 @@ public class MLContextFrameTest extends AutomatedTestBase {
                Assert.assertEquals(18.0, matrix[2][0], 0.0);
        }
 
+       @Test
+       public void testInputFrameAndMatrixOutputMatrixAndFrame() {
+               System.out.println("MLContextFrameTest - input frame and 
matrix, output matrix and frame");
+               
+               Row[] rowsA = {RowFactory.create("Doc1", "Feat1", 10), 
RowFactory.create("Doc1", "Feat2", 20), RowFactory.create("Doc2", "Feat1", 31)};
+
+               JavaRDD<Row> javaRddRowA = sc. parallelize( 
Arrays.asList(rowsA)); 
+
+               SQLContext sqlContext = new SQLContext(sc);
+
+               List<StructField> fieldsA = new ArrayList<StructField>();
+               fieldsA.add(DataTypes.createStructField("myID", 
DataTypes.StringType, true));
+               fieldsA.add(DataTypes.createStructField("FeatureName", 
DataTypes.StringType, true));
+               fieldsA.add(DataTypes.createStructField("FeatureValue", 
DataTypes.IntegerType, true));
+               StructType schemaA = DataTypes.createStructType(fieldsA);
+               DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, 
schemaA);
+
+               String dmlString = "[tA, tAM] = transformencode (target = A, 
spec = \"{ids: false ,recode: [ myID, FeatureName ]}\");";
+
+               Script script = dml(dmlString)
+                               .in("A", dataFrameA,
+                                               new 
FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) 
dataFrameA.columns().length))
+                               .out("tA").out("tAM");
+               MLResults results = ml.execute(script);
+
+               double[][] matrixtA = results.getMatrixAs2DDoubleArray("tA");
+               Assert.assertEquals(10.0, matrixtA[0][2], 0.0);
+               Assert.assertEquals(20.0, matrixtA[1][2], 0.0);
+               Assert.assertEquals(31.0, matrixtA[2][2], 0.0);
+
+               DataFrame dataFrame_tA = results.getMatrix("tA").toDF();
+               System.out.println("Number of matrix tA rows = " + 
dataFrame_tA.count());
+               dataFrame_tA.printSchema();
+               dataFrame_tA.show();
+               
+               DataFrame dataFrame_tAM = results.getFrame("tAM").toDF();
+               System.out.println("Number of frame tAM rows = " + 
dataFrame_tAM.count());
+               dataFrame_tAM.printSchema();
+               dataFrame_tAM.show();
+       }
+
        // NOTE: the ordering of the frame values seem to come out differently 
here
        // than in the scala shell,
        // so this should be investigated or explained.

Reply via email to