Repository: systemml Updated Branches: refs/heads/master 87bc3584d -> 0177a1310
[SYSTEMML-2373] Fix codegen integration nary min/max (costs, single op) This patch fixes the codegen support for nary min/max by (1) including nary ops in the cost model and (2) ensuring that single nary ops are not eagerly pruned out before optimization. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/37e66039 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/37e66039 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/37e66039 Branch: refs/heads/master Commit: 37e66039da49f79ca73489683bc9b02a339baf0f Parents: 87bc358 Author: Matthias Boehm <[email protected]> Authored: Fri Jun 8 20:10:36 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jun 8 20:10:36 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 2 +- .../sysml/hops/codegen/opt/PlanSelection.java | 2 +- .../opt/PlanSelectionFuseCostBasedV2.java | 5 +++ .../hops/codegen/template/CPlanMemoTable.java | 2 +- .../hops/codegen/template/TemplateUtils.java | 7 +++++ .../functions/codegen/CellwiseTmplTest.java | 26 ++++++++++++--- .../scripts/functions/codegen/cellwisetmpl24.R | 31 ++++++++++++++++++ .../functions/codegen/cellwisetmpl24.dml | 33 ++++++++++++++++++++ 8 files changed, 100 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 6a23c8d..fd012ec 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -749,7 +749,7 @@ public class SpoofCompiler || 
((CNodeRow)tpl).getRowType()==RowType.ROW_AGG ) && TemplateUtils.hasSingleOperation(tpl)) || TemplateUtils.hasNoOperation(tpl) ) - { + { cplans2.remove(e.getKey()); if( LOG.isTraceEnabled() ) LOG.trace("Removed cplan with single operation."); http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java index 5242211..c676618 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java @@ -70,7 +70,7 @@ public abstract class PlanSelection } protected void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) - { + { if( isVisited(current.getHopID(), currentType) || (partition!=null && !partition.contains(current.getHopID())) ) return; http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 3db4ce8..e96f3a7 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -51,6 +51,7 @@ import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.NaryOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; import 
org.apache.sysml.hops.ReorgOp; @@ -1118,6 +1119,10 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection + "implemented yet for: "+((TernaryOp)current).getOp()); } } + else if( current instanceof NaryOp ) { + costs = HopRewriteUtils.isNary(current, OpOpN.MIN, OpOpN.MAX) ? + current.getInput().size() : 1; + } else if( current instanceof ParameterizedBuiltinOp ) { costs = 1; } http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 5c90ca0..7e18dbe 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -244,7 +244,7 @@ public class CPlanMemoTable Iterator<Entry<Long, List<MemoTableEntry>>> iter = _plans.entrySet().iterator(); while( iter.hasNext() ) { Entry<Long, List<MemoTableEntry>> e = iter.next(); - if( !ix.contains(e.getKey()) ) { + if( !(ix.contains(e.getKey()) || TemplateUtils.isValidSingleOperation(_hopRefs.get(e.getKey()))) ) { e.getValue().removeIf(p -> !p.hasPlanRef()); if( e.getValue().isEmpty() ) iter.remove(); http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 3a0b1ed..232b214 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -36,6 +36,8 @@ import org.apache.sysml.hops.ParameterizedBuiltinOp; 
import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.SpoofCompiler; @@ -343,6 +345,11 @@ public class TemplateUtils && hasOnlyDataNodeOrLookupInputs(output); } + public static boolean isValidSingleOperation(Hop hop) { + return HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) + || HopRewriteUtils.isUnary(hop, OpOp1.EXP, OpOp1.LOG); + } + public static boolean hasNoOperation(CNodeTpl tpl) { return tpl.getOutput() instanceof CNodeData || isLookup(tpl.getOutput(), true); http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index cfcabde..90a6dc8 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -59,7 +59,8 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME21 = TEST_NAME+21; //relu operation, (X>0)*dout private static final String TEST_NAME22 = TEST_NAME+22; //sum(X * seq(1,N) + t(seq(M,1))) private static final String TEST_NAME23 = TEST_NAME+23; //sum(min(X,Y,Z)) - + private static final String TEST_NAME24 = TEST_NAME+24; //min(X, Y, Z, 3, 7) + private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/"; private final static String TEST_CONF6 = 
"SystemML-config-codegen6.xml"; @@ -71,12 +72,12 @@ public class CellwiseTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for( int i=1; i<=23; i++ ) { + for( int i=1; i<=24; i++ ) { addTestConfiguration( TEST_NAME+i, new TestConfiguration( - TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); + TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); } } - + @Test public void testCodegenCellwiseRewrite1() { testCodegenIntegration( TEST_NAME1, true, ExecType.CP ); @@ -398,6 +399,21 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwiseRewrite23_sp() { testCodegenIntegration( TEST_NAME23, true, ExecType.SPARK ); } + + @Test + public void testCodegenCellwiseRewrite24() { + testCodegenIntegration( TEST_NAME24, true, ExecType.CP ); + } + + @Test + public void testCodegenCellwise24() { + testCodegenIntegration( TEST_NAME24, false, ExecType.CP ); + } + + @Test + public void testCodegenCellwiseRewrite24_sp() { + testCodegenIntegration( TEST_NAME24, true, ExecType.SPARK ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { @@ -467,7 +483,7 @@ public class CellwiseTmplTest extends AutomatedTestBase Assert.assertTrue(!heavyHittersContainsSubString("xor")); else if( testname.equals(TEST_NAME22) ) Assert.assertTrue(!heavyHittersContainsSubString("seq")); - else if( testname.equals(TEST_NAME23) ) + else if( testname.equals(TEST_NAME23) || testname.equals(TEST_NAME24) ) Assert.assertTrue(!heavyHittersContainsSubString("min","nmin")); } finally { http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/test/scripts/functions/codegen/cellwisetmpl24.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl24.R b/src/test/scripts/functions/codegen/cellwisetmpl24.R new file mode 100644 index 0000000..a01a2a4 --- /dev/null +++ 
b/src/test/scripts/functions/codegen/cellwisetmpl24.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(6, 500, 2); +Y = matrix(7, 500, 2); +Z = matrix(8, 500, 2); +R = as.matrix(sum(pmin(X,Y,Z,3,7))); + +writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/37e66039/src/test/scripts/functions/codegen/cellwisetmpl24.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl24.dml b/src/test/scripts/functions/codegen/cellwisetmpl24.dml new file mode 100644 index 0000000..75e3564 --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl24.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(6, 500, 2); +Y = matrix(7, 500, 2); +Z = matrix(8, 500, 2); + +while(FALSE){} + +R = min(X,Y,Z,3,7); + +while(FALSE){} + +R = as.matrix(sum(R)); +write(R, $1)
