[SYSTEMML-2077] Extended split-dag rewrite for new eval function. This patch extends the existing rewrite for splitting DAGs after data-dependent operators to include the new second-order eval function as an interesting candidate. Eval allows the evaluation of unknown functions and thus always has unknown output sizes. The modified rewrite therefore creates a hook for dynamic recompilation to adapt the subsequent runtime plan.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/108ee7ae Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/108ee7ae Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/108ee7ae Branch: refs/heads/master Commit: 108ee7ae8098499248b5a91ab9378007d35509ae Parents: 2af8496 Author: Matthias Boehm <[email protected]> Authored: Thu Mar 8 23:11:53 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Mar 8 23:58:08 2018 -0800 ---------------------------------------------------------------------- .../RewriteSplitDagDataDependentOperators.java | 59 ++++++++++---------- .../integration/mlcontext/MLContextTest.java | 13 ++++- .../apache/sysml/api/mlcontext/eval2-test.dml | 37 ++++++++++++ 3 files changed, 78 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/108ee7ae/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 7b4a733..ebf275b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -32,13 +32,13 @@ import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.Hop.OpOp3; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; import 
org.apache.sysml.hops.ParameterizedBuiltinOp; -import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.DataIdentifier; @@ -242,11 +242,11 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //collect data dependent operations (to be extended as necessary) //#1 removeEmpty - if( hop instanceof ParameterizedBuiltinOp + if( hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp()==ParamBuiltinOp.RMEMPTY && !noSplitRequired && !(hop.getParent().size()==1 && hop.getParent().get(0) instanceof TernaryOp - && ((TernaryOp)hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable())) + && ((TernaryOp)hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable())) { ParameterizedBuiltinOp pbhop = (ParameterizedBuiltinOp)hop; cand.add(pbhop); @@ -268,23 +268,22 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //configure rmEmpty to directly output selection vector //(only applied if dynamic recompilation enabled) - if( ConfigurationManager.isDynamicRecompilation() ) + if( ConfigurationManager.isDynamicRecompilation() ) pbhop.setOutputPermutationMatrix(true); for( Hop p : hop.getParent() ) - ((AggBinaryOp)p).setHasLeftPMInput(true); + ((AggBinaryOp)p).setHasLeftPMInput(true); } } //#2 ctable with unknown dims - if( hop instanceof TernaryOp - && ((TernaryOp) hop).getOp()==OpOp3.CTABLE + if( HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) && hop.getInput().size() < 4 //dims not provided && !noSplitRequired ) { cand.add(hop); investigateChilds = false; - //keep interesting consumer information, flag hops accordingly + //keep interesting consumer information, flag hops accordingly boolean onlyPMM = true; for( Hop p : hop.getParent() ) { onlyPMM &= (p instanceof AggBinaryOp && hop == p.getInput().get(0)); @@ -293,29 +292,31 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite if( onlyPMM && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0)) ) hop.setOutputEmptyBlocks(false); } - - //#3 orderby childs computed in same DAG - if( hop instanceof ReorgOp - && ((ReorgOp)hop).getOp()==ReOrgOp.SORT ) - { - //params 'decreasing' / 'indexreturn' - for( int i=2; i<=3; i++ ) { - Hop c = hop.getInput().get(i); - if( !(c instanceof LiteralOp || c instanceof DataOp) ){ - cand.add(c); - c.setVisited(); - investigateChilds = false; - } - - } - } + + //#3 orderby childs computed in same DAG + if( HopRewriteUtils.isReorg(hop, ReOrgOp.SORT) ){ + //params 'decreasing' / 'indexreturn' + for( int i=2; i<=3; i++ ) { + Hop c = hop.getInput().get(i); + if( !(c instanceof LiteralOp || c instanceof DataOp) ){ + cand.add(c); + c.setVisited(); + investigateChilds = false; + } + } + } + + //#4 second-order eval function + if( HopRewriteUtils.isNary(hop, OpOpN.EVAL) && !noSplitRequired ) { + cand.add(hop); + investigateChilds = false; + } //process children (if not already found a special operators; - //otherwise, processed by recursive rule application) - if( investigateChilds ) - if( hop.getInput()!=null ) - for( Hop c : hop.getInput() ) - rCollectDataDependentOperators(c, cand); + //otherwise, processed by recursive rule application) + if( investigateChilds && hop.getInput()!=null ) + for( Hop c : hop.getInput() ) + rCollectDataDependentOperators(c, cand); hop.setVisited(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/108ee7ae/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index e70faa9..a6b2ea2 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -75,6 +75,7 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.utils.Statistics; import org.junit.Assert; import org.junit.Test; @@ -89,12 +90,20 @@ import scala.collection.Seq; public class MLContextTest extends MLContextTestBase { @Test - public void testCreateDMLScriptBasedOnFileAndExecuteEvalTest() { - System.out.println("MLContextTest - create DML script based on file and execute"); + public void testBasicExecuteEvalTest() { + System.out.println("MLContextTest - basic eval test"); setExpectedStdOut("10"); Script script = dmlFromFile(baseDirectory + File.separator + "eval-test.dml"); ml.execute(script); } + + @Test + public void testRewriteExecuteEvalTest() { + System.out.println("MLContextTest - eval rewrite test"); + Script script = dmlFromFile(baseDirectory + File.separator + "eval2-test.dml"); + ml.execute(script); + Assert.assertTrue(Statistics.getNoOfExecutedSPInst() == 0); + } @Test public void testCreateDMLScriptBasedOnStringAndExecute() { http://git-wip-us.apache.org/repos/asf/systemml/blob/108ee7ae/src/test/scripts/org/apache/sysml/api/mlcontext/eval2-test.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/eval2-test.dml b/src/test/scripts/org/apache/sysml/api/mlcontext/eval2-test.dml new file mode 100644 index 0000000..0a50b7f --- /dev/null +++ b/src/test/scripts/org/apache/sysml/api/mlcontext/eval2-test.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +f1 = function (matrix[double] A, double b) return (matrix[double] R) { + R = A * b; +} + +f2 = function (matrix[double] A, double b) return (matrix[double] R) { + R = A + b; +} + +# some variables +X = rand(rows=100, cols=10) + +R = matrix(0,0,0); +for(i in 1:2) + R = eval("f"+i, X, 7) + 7; + +print(sum(R));
