[SYSTEMML-804] Size propagation frame transform functions, recompile

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a39aecff
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a39aecff
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a39aecff

Branch: refs/heads/master
Commit: a39aecffa0868853b2c60ce412470b7074e0dd53
Parents: c7beb50
Author: Matthias Boehm <[email protected]>
Authored: Mon Jul 11 22:38:59 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jul 12 11:31:57 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/ParameterizedBuiltinOp.java      | 31 ++++++++--
 .../apache/sysml/hops/recompile/Recompiler.java |  8 ++-
 .../controlprogram/caching/FrameObject.java     |  1 +
 .../context/SparkExecutionContext.java          |  4 +-
 .../TransformFrameEncodeApplyTest.java          | 61 +++++++++++++++++++-
 .../TransformFrameEncodeDecodeTest.java         | 27 +++++++++
 6 files changed, 123 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index f1ca98c..b3aec91 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -1062,8 +1062,7 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
                        Hop dir = getInput().get(_paramIndexMap.get("dir"));
                        double maxVal = HopRewriteUtils.getDoubleValueSafe((LiteralOp)max);
                        String dirVal = ((LiteralOp)dir).getStringValue();
-                       if( mc.dimsKnown() )
-                       {
+                       if( mc.dimsKnown() ) {
                                long lnnz = mc.nnzKnown() ? mc.getNonZeros() : mc.getRows();
                                if( "cols".equals(dirVal) ) { //expand horizontally
                                        ret = new long[]{mc.getRows(), UtilFunctions.toLong(maxVal), lnnz};
@@ -1073,6 +1072,20 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
                                }
                        }
                }
+               else if( _op == ParamBuiltinOp.TRANSFORMDECODE ) {
+                       if( mc.dimsKnown() ) {
+                               //rows: remain unchanged
+                               //cols: dummy coding might decrease but never increase cols
+                               return new long[]{mc.getRows(), mc.getCols(), mc.getRows()*mc.getCols()};
+                       }
+               }
+               else if( _op == ParamBuiltinOp.TRANSFORMAPPLY ) {
+                       if( mc.dimsKnown() ) {
+                               //rows: omitting might decrease but never increase rows
+                               //cols: dummy coding and binning might increase cols but nnz stays constant
+                               return new long[]{mc.getRows(), mc.getCols(), mc.getRows()*mc.getCols()};
+                       }
+               }
                
                return ret;
        }
@@ -1205,11 +1218,21 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
                                
                                break;  
                        }
-                       case TRANSFORMAPPLY: {
+                       case TRANSFORMDECODE: {
                                Hop target = getInput().get(_paramIndexMap.get("target"));
-                               setDim1( target.getDim1() ); //rows remain unchanged
+                               //rows remain unchanged for recoding and dummy coding
+                               setDim1( target.getDim1() );
+                               //cols remain unchanged only if no dummy coding
+                               //TODO parse json spec
+                               break;
                        }
+                       
+                       case TRANSFORMAPPLY: {
+                               //rows remain unchanged only if no omitting
+                               //cols remain unchanged if no dummy coding
+                               //TODO parse json spec
                                break;
+                       }
                        default:
                                //do nothing
                                break;

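Side note on the size-propagation logic above: when the input dimensions are known, both new branches fall back to the input characteristics with a dense nnz estimate. The following is an illustrative, standalone restatement of those bounds; it is not code from this commit, and the class and method names are hypothetical.

// Illustrative sketch only -- not part of the commit; names are hypothetical.
final class TransformSizeSketch {
        /** transformdecode: rows remain unchanged; dummy coding can only decrease,
         *  never increase, the number of columns, so the input dims are an upper bound. */
        static long[] decodeOutputSize(long rows, long cols) {
                return new long[]{ rows, cols, rows * cols };
        }
        /** transformapply: omitting can only decrease rows; dummy coding and binning
         *  may increase columns, but the number of non-zeros stays constant, so the
         *  input's rows*cols remains a reasonable nnz estimate. */
        static long[] applyOutputSize(long rows, long cols) {
                return new long[]{ rows, cols, rows * cols };
        }
}
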
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
index 5e65bf1..f7204e8 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
@@ -1638,13 +1638,17 @@ public class Recompiler
                        if( vars.keySet().contains( varName ) )
                        {
                                Data dat = vars.get(varName);
-                               if( dat instanceof MatrixObject )
-                               {
+                               if( dat instanceof MatrixObject ) {
                                        MatrixObject mo = (MatrixObject) dat;
                                        d.setDim1(mo.getNumRows());
                                        d.setDim2(mo.getNumColumns());
                                        d.setNnz(mo.getNnz());
                                }
+                               else if( dat instanceof FrameObject ) {
+                                       FrameObject fo = (FrameObject) dat;
+                                       d.setDim1(fo.getNumRows());
+                                       d.setDim2(fo.getNumColumns());
+                               }
                        }
                }
                //special case for persistent reads with unknown size (read-after-write)

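The Recompiler change extends the existing matrix handling so that frame dimensions also feed dynamic recompilation. Below is a minimal sketch of that dispatch, assuming the types shown in the hunk (Hop, MatrixObject, FrameObject) plus LocalVariableMap and Data; the package paths are assumptions and the helper itself is hypothetical, not the actual Recompiler method.

// Hypothetical helper -- illustrative only, not the actual Recompiler code.
import org.apache.sysml.hops.Hop;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.cp.Data;

final class RecompileDimsSketch {
        /** Refresh a DataOp's size information from the live object bound to varName. */
        static void refreshDataOpDims(Hop d, LocalVariableMap vars, String varName) {
                if( !vars.keySet().contains(varName) )
                        return;
                Data dat = vars.get(varName);
                if( dat instanceof MatrixObject ) {
                        MatrixObject mo = (MatrixObject) dat;
                        d.setDim1(mo.getNumRows());
                        d.setDim2(mo.getNumColumns());
                        d.setNnz(mo.getNnz());
                }
                else if( dat instanceof FrameObject ) {
                        FrameObject fo = (FrameObject) dat;
                        d.setDim1(fo.getNumRows());
                        d.setDim2(fo.getNumColumns());
                        //frames carry no sparsity; nnz comes from FrameObject metadata (see next file)
                }
        }
}
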
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
index db98a3e..e3d2332 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
@@ -121,6 +121,7 @@ public class FrameObject extends CacheableData<FrameBlock>
                //update matrix characteristics
                MatrixCharacteristics mc = ((MatrixDimensionsMetaData) _metaData).getMatrixCharacteristics();
                mc.setDimension( _data.getNumRows(),_data.getNumColumns() );
+               mc.setNonZeros(_data.getNumRows()*_data.getNumColumns());
                
                //update schema information
                _schema = _data.getSchema();

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 58027ce..99614f2 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -979,8 +979,10 @@ public class SparkExecutionContext extends ExecutionContext
                
                        //copy into output frame
                        out.copy( ix, ix+block.getNumRows()-1, 0, block.getNumColumns()-1, block );
-                       if( ix == 0 )
+                       if( ix == 0 ) {
+                               out.setColumnNames(block.getColumnNames());
                                out.setColumnMetadata(block.getColumnMetadata());
+                       }
                }
                
                if (DMLScript.STATISTICS) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java
index 27d58f9..2d17c17 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysml.test.integration.functions.transform;
 
+import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
@@ -29,6 +30,7 @@ import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
 
 public class TransformFrameEncodeApplyTest extends AutomatedTestBase 
 {
@@ -77,6 +79,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesRecodeIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, false);
+       }
+       
+       @Test
        public void testHomesDummycodeIDsSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, false);
        }
@@ -87,6 +94,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesDummycodeIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, false);
+       }
+       
+       @Test
        public void testHomesBinningIDsSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.BIN, false);
        }
@@ -97,6 +109,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesBinningIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.BIN, false);
+       }
+       
+       @Test
        public void testHomesOmitIDsSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.OMIT, false);
        }
@@ -107,6 +124,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesOmitIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.OMIT, false);
+       }
+       
+       @Test
        public void testHomesImputeIDsSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.IMPUTE, false);
        }
@@ -115,6 +137,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        public void testHomesImputeIDsSparkCSV() {
                runTransformTest(RUNTIME_PLATFORM.SPARK, "csv", TransformType.IMPUTE, false);
        }
+       
+       @Test
+       public void testHomesImputeIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.IMPUTE, false);
+       }
 
        @Test
        public void testHomesRecodeColnamesSingleNodeCSV() {
@@ -127,6 +154,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesRecodeColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, true);
+       }
+       
+       @Test
        public void testHomesDummycodeColnamesSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, true);
        }
@@ -137,6 +169,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesDummycodeColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, true);
+       }
+       
+       @Test
        public void testHomesBinningColnamesSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.BIN, true);
        }
@@ -147,6 +184,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesBinningColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.BIN, true);
+       }
+       
+       @Test
        public void testHomesOmitColnamesSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.OMIT, true);
        }
@@ -157,6 +199,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesOmitvColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.OMIT, true);
+       }
+       
+       @Test
        public void testHomesImputeColnamesSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.IMPUTE, true);
        }
@@ -166,6 +213,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
                runTransformTest(RUNTIME_PLATFORM.SPARK, "csv", TransformType.IMPUTE, true);
        }
        
+       @Test
+       public void testHomesImputeColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.IMPUTE, true);
+       }
+       
        /**
         * 
         * @param rt
@@ -202,7 +254,7 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
-                       programArgs = new String[]{"-explain","-nvargs", 
+                       programArgs = new String[]{"-explain", "recompile_hops", "-nvargs", 
                                "DATA=" + HOME + "input/" + DATASET,
                                "TFSPEC=" + HOME + "input/" + SPEC,
                                "TFDATA1=" + output("tfout1"),
@@ -219,7 +271,12 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
                        double[][] R2 = DataConverter.convertToDoubleMatrix(MatrixReaderFactory
                                .createMatrixReader(InputInfo.CSVInputInfo)
                                .readMatrixFromHDFS(output("tfout2"), -1L, -1L, 1000, 1000, -1));
-                       TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0);
+                       TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0);
+                       
+                       if( rt == RUNTIME_PLATFORM.HYBRID_SPARK ) {
+                               Assert.assertEquals("Wrong number of executed Spark instructions: " + 
+                                       Statistics.getNoOfExecutedSPInst(), new Long(2), new Long(Statistics.getNoOfExecutedSPInst()));
+                       }
                }
                catch(Exception ex) {
                        throw new RuntimeException(ex);

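The new HYBRID_SPARK tests additionally assert, via Statistics, that exactly two Spark instructions execute; presumably the point is that with frame sizes now propagated, hops recompilation keeps the remaining operations in CP. The same check can be written without the boxed Long values (an illustrative variant only, relying on the test's imports shown above; the committed code is as in the hunk):

// Illustrative variant of the committed assertion (not part of the commit).
if( rt == RUNTIME_PLATFORM.HYBRID_SPARK ) {
        long numSp = Statistics.getNoOfExecutedSPInst();
        Assert.assertEquals("Wrong number of executed Spark instructions: " + numSp, 2L, numSp);
}
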
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java
index 0bdf4da..b676989 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysml.test.integration.functions.transform;
 
+import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
@@ -32,6 +33,7 @@ import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
 
 public class TransformFrameEncodeDecodeTest extends AutomatedTestBase 
 {
@@ -72,6 +74,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesRecodeIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, false);
+       }
+       
+       @Test
        public void testHomesDummycodeIDsSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, false);
        }
@@ -82,6 +89,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesDummycodeIDsHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, false);
+       }
+       
+       @Test
        public void testHomesRecodeColnamesSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.RECODE, true);
        }
@@ -92,6 +104,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase
        }
        
        @Test
+       public void testHomesRecodeColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, true);
+       }
+       
+       @Test
        public void testHomesDummycodeColnamesSingleNodeCSV() {
                runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, true);
        }
@@ -101,6 +118,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase
                runTransformTest(RUNTIME_PLATFORM.SPARK, "csv", TransformType.DUMMY, true);
        }
        
+       @Test
+       public void testHomesDummycodeColnamesHybridCSV() {
+               runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, true);
+       }
+       
        /**
         * 
         * @param rt
@@ -153,6 +175,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase
                        String[][] R1 = DataConverter.convertToStringFrame(fb1);
                        String[][] R2 = DataConverter.convertToStringFrame(fb2);
                        TestUtils.compareFrames(R1, R2, R1.length, R1[0].length);
+                       
+                       if( rt == RUNTIME_PLATFORM.HYBRID_SPARK ) {
+                               Assert.assertEquals("Wrong number of executed Spark instructions: " + 
+                                       Statistics.getNoOfExecutedSPInst(), new Long(2), new Long(Statistics.getNoOfExecutedSPInst()));
+                       }
                }
                catch(Exception ex) {
                        throw new RuntimeException(ex);
