[SYSTEMML-804] Size propagation frame transform functions, recompile Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a39aecff Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a39aecff Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a39aecff
Branch: refs/heads/master Commit: a39aecffa0868853b2c60ce412470b7074e0dd53 Parents: c7beb50 Author: Matthias Boehm <[email protected]> Authored: Mon Jul 11 22:38:59 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Jul 12 11:31:57 2016 -0700 ---------------------------------------------------------------------- .../sysml/hops/ParameterizedBuiltinOp.java | 31 ++++++++-- .../apache/sysml/hops/recompile/Recompiler.java | 8 ++- .../controlprogram/caching/FrameObject.java | 1 + .../context/SparkExecutionContext.java | 4 +- .../TransformFrameEncodeApplyTest.java | 61 +++++++++++++++++++- .../TransformFrameEncodeDecodeTest.java | 27 +++++++++ 6 files changed, 123 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java index f1ca98c..b3aec91 100644 --- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java @@ -1062,8 +1062,7 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop Hop dir = getInput().get(_paramIndexMap.get("dir")); double maxVal = HopRewriteUtils.getDoubleValueSafe((LiteralOp)max); String dirVal = ((LiteralOp)dir).getStringValue(); - if( mc.dimsKnown() ) - { + if( mc.dimsKnown() ) { long lnnz = mc.nnzKnown() ? 
mc.getNonZeros() : mc.getRows(); if( "cols".equals(dirVal) ) { //expand horizontally ret = new long[]{mc.getRows(), UtilFunctions.toLong(maxVal), lnnz}; @@ -1073,6 +1072,20 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop } } } + else if( _op == ParamBuiltinOp.TRANSFORMDECODE ) { + if( mc.dimsKnown() ) { + //rows: remain unchanged + //cols: dummy coding might decrease but never increase cols + return new long[]{mc.getRows(), mc.getCols(), mc.getRows()*mc.getCols()}; + } + } + else if( _op == ParamBuiltinOp.TRANSFORMAPPLY ) { + if( mc.dimsKnown() ) { + //rows: omitting might decrease but never increase rows + //cols: dummy coding and binning might increase cols but nnz stays constant + return new long[]{mc.getRows(), mc.getCols(), mc.getRows()*mc.getCols()}; + } + } return ret; } @@ -1205,11 +1218,21 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop break; } - case TRANSFORMAPPLY: { + case TRANSFORMDECODE: { Hop target = getInput().get(_paramIndexMap.get("target")); - setDim1( target.getDim1() ); //rows remain unchanged + //rows remain unchanged for recoding and dummy coding + setDim1( target.getDim1() ); + //cols remain unchanged only if no dummy coding + //TODO parse json spec + break; } + + case TRANSFORMAPPLY: { + //rows remain unchanged only if no omitting + //cols remain unchanged if no dummy coding + //TODO parse json spec break; + } default: //do nothing break; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java index 5e65bf1..f7204e8 100644 --- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java @@ -1638,13 +1638,17 @@ public class
Recompiler if( vars.keySet().contains( varName ) ) { Data dat = vars.get(varName); - if( dat instanceof MatrixObject ) - { + if( dat instanceof MatrixObject ) { MatrixObject mo = (MatrixObject) dat; d.setDim1(mo.getNumRows()); d.setDim2(mo.getNumColumns()); d.setNnz(mo.getNnz()); } + else if( dat instanceof FrameObject ) { + FrameObject fo = (FrameObject) dat; + d.setDim1(fo.getNumRows()); + d.setDim2(fo.getNumColumns()); + } } } //special case for persistent reads with unknown size (read-after-write) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java index db98a3e..e3d2332 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java @@ -121,6 +121,7 @@ public class FrameObject extends CacheableData<FrameBlock> //update matrix characteristics MatrixCharacteristics mc = ((MatrixDimensionsMetaData) _metaData).getMatrixCharacteristics(); mc.setDimension( _data.getNumRows(),_data.getNumColumns() ); + mc.setNonZeros(_data.getNumRows()*_data.getNumColumns()); //update schema information _schema = _data.getSchema(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 58027ce..99614f2 100644 --- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -979,8 +979,10 @@ public class SparkExecutionContext extends ExecutionContext //copy into output frame out.copy( ix, ix+block.getNumRows()-1, 0, block.getNumColumns()-1, block ); - if( ix == 0 ) + if( ix == 0 ) { + out.setColumnNames(block.getColumnNames()); out.setColumnMetadata(block.getColumnMetadata()); + } } if (DMLScript.STATISTICS) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java index 27d58f9..2d17c17 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeApplyTest.java @@ -19,6 +19,7 @@ package org.apache.sysml.test.integration.functions.transform; +import org.junit.Assert; import org.junit.Test; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; @@ -29,6 +30,7 @@ import org.apache.sysml.runtime.util.DataConverter; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; public class TransformFrameEncodeApplyTest extends AutomatedTestBase { @@ -77,6 +79,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesRecodeIDsHybridCSV() { + 
runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, false); + } + + @Test public void testHomesDummycodeIDsSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, false); } @@ -87,6 +94,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesDummycodeIDsHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, false); + } + + @Test public void testHomesBinningIDsSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.BIN, false); } @@ -97,6 +109,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesBinningIDsHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.BIN, false); + } + + @Test public void testHomesOmitIDsSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.OMIT, false); } @@ -107,6 +124,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesOmitIDsHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.OMIT, false); + } + + @Test public void testHomesImputeIDsSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.IMPUTE, false); } @@ -115,6 +137,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase public void testHomesImputeIDsSparkCSV() { runTransformTest(RUNTIME_PLATFORM.SPARK, "csv", TransformType.IMPUTE, false); } + + @Test + public void testHomesImputeIDsHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.IMPUTE, false); + } @Test public void testHomesRecodeColnamesSingleNodeCSV() { @@ -127,6 +154,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesRecodeColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", 
TransformType.RECODE, true); + } + + @Test public void testHomesDummycodeColnamesSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, true); } @@ -137,6 +169,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesDummycodeColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, true); + } + + @Test public void testHomesBinningColnamesSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.BIN, true); } @@ -147,6 +184,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesBinningColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.BIN, true); + } + + @Test public void testHomesOmitColnamesSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.OMIT, true); } @@ -157,6 +199,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase } @Test + public void testHomesOmitvColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.OMIT, true); + } + + @Test public void testHomesImputeColnamesSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.IMPUTE, true); } @@ -166,6 +213,11 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase runTransformTest(RUNTIME_PLATFORM.SPARK, "csv", TransformType.IMPUTE, true); } + @Test + public void testHomesImputeColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.IMPUTE, true); + } + /** * * @param rt @@ -202,7 +254,7 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; - programArgs = new String[]{"-explain","-nvargs", + programArgs = new String[]{"-explain", "recompile_hops", "-nvargs", "DATA=" + HOME + "input/" + 
DATASET, "TFSPEC=" + HOME + "input/" + SPEC, "TFDATA1=" + output("tfout1"), @@ -219,7 +271,12 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase double[][] R2 = DataConverter.convertToDoubleMatrix(MatrixReaderFactory .createMatrixReader(InputInfo.CSVInputInfo) .readMatrixFromHDFS(output("tfout2"), -1L, -1L, 1000, 1000, -1)); - TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0); + TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0); + + if( rt == RUNTIME_PLATFORM.HYBRID_SPARK ) { + Assert.assertEquals("Wrong number of executed Spark instructions: " + + Statistics.getNoOfExecutedSPInst(), new Long(2), new Long(Statistics.getNoOfExecutedSPInst())); + } } catch(Exception ex) { throw new RuntimeException(ex); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a39aecff/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java index 0bdf4da..b676989 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/transform/TransformFrameEncodeDecodeTest.java @@ -19,6 +19,7 @@ package org.apache.sysml.test.integration.functions.transform; +import org.junit.Assert; import org.junit.Test; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; @@ -32,6 +33,7 @@ import org.apache.sysml.runtime.util.DataConverter; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; public class 
TransformFrameEncodeDecodeTest extends AutomatedTestBase { @@ -72,6 +74,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase } @Test + public void testHomesRecodeIDsHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, false); + } + + @Test public void testHomesDummycodeIDsSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, false); } @@ -82,6 +89,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase } @Test + public void testHomesDummycodeIDsHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, false); + } + + @Test public void testHomesRecodeColnamesSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.RECODE, true); } @@ -92,6 +104,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase } @Test + public void testHomesRecodeColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.RECODE, true); + } + + @Test public void testHomesDummycodeColnamesSingleNodeCSV() { runTransformTest(RUNTIME_PLATFORM.SINGLE_NODE, "csv", TransformType.DUMMY, true); } @@ -101,6 +118,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase runTransformTest(RUNTIME_PLATFORM.SPARK, "csv", TransformType.DUMMY, true); } + @Test + public void testHomesDummycodeColnamesHybridCSV() { + runTransformTest(RUNTIME_PLATFORM.HYBRID_SPARK, "csv", TransformType.DUMMY, true); + } + /** * * @param rt @@ -153,6 +175,11 @@ public class TransformFrameEncodeDecodeTest extends AutomatedTestBase String[][] R1 = DataConverter.convertToStringFrame(fb1); String[][] R2 = DataConverter.convertToStringFrame(fb2); TestUtils.compareFrames(R1, R2, R1.length, R1[0].length); + + if( rt == RUNTIME_PLATFORM.HYBRID_SPARK ) { + Assert.assertEquals("Wrong number of executed Spark instructions: " + + Statistics.getNoOfExecutedSPInst(), new Long(2), new 
Long(Statistics.getNoOfExecutedSPInst())); + } } catch(Exception ex) { throw new RuntimeException(ex);
